mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-26 10:28:21 +08:00
Compare commits
2 Commits
main
...
fix-torcha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e01e02395 | ||
|
|
5e5b575fb3 |
@@ -1,11 +0,0 @@
|
||||
# PR Review Rules
|
||||
|
||||
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
|
||||
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
|
||||
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
## Common mistakes (add new rules below this line)
|
||||
38
.github/workflows/claude_review.yml
vendored
38
.github/workflows/claude_review.yml
vendored
@@ -1,38 +0,0 @@
|
||||
name: Claude PR Review
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'issue_comment' &&
|
||||
github.event.issue.pull_request &&
|
||||
github.event.issue.state == 'open' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
) || (
|
||||
github.event_name == 'pull_request_review_comment' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."
|
||||
2
.github/workflows/nightly_tests.yml
vendored
2
.github/workflows/nightly_tests.yml
vendored
@@ -341,7 +341,7 @@ jobs:
|
||||
additional_deps: ["peft", "kernels"]
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
additional_deps: [mslk-cuda]
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
additional_deps: []
|
||||
|
||||
@@ -670,10 +670,6 @@
|
||||
- local: api/pipelines/z_image
|
||||
title: Z-Image
|
||||
title: Image
|
||||
- sections:
|
||||
- local: api/pipelines/llada2
|
||||
title: LLaDA2
|
||||
title: Text
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
@@ -722,8 +718,6 @@
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/block_refinement
|
||||
title: BlockRefinementScheduler
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
title: CMStochasticIterativeScheduler
|
||||
- local: api/schedulers/ddim_cogvideox
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
|
||||
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
|
||||
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
|
||||
steps.
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
|
||||
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
output = pipe(
|
||||
prompt="Write a short poem about the ocean.",
|
||||
gen_length=256,
|
||||
block_length=32,
|
||||
num_inference_steps=32,
|
||||
threshold=0.7,
|
||||
editing_threshold=0.5,
|
||||
max_post_steps=16,
|
||||
temperature=0.0,
|
||||
)
|
||||
print(output.texts[0])
|
||||
```
|
||||
|
||||
## Callbacks
|
||||
|
||||
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
|
||||
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
|
||||
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
|
||||
window.
|
||||
|
||||
```py
|
||||
def on_step_end(pipe, step, timestep, callback_kwargs):
|
||||
block_x = callback_kwargs["block_x"]
|
||||
# Inspect or modify `block_x` here.
|
||||
return {"block_x": block_x}
|
||||
|
||||
out = pipe(
|
||||
prompt="Write a short poem.",
|
||||
callback_on_step_end=on_step_end,
|
||||
callback_on_step_end_tensor_inputs=["block_x"],
|
||||
)
|
||||
```
|
||||
|
||||
## Recommended parameters
|
||||
|
||||
LLaDA2.1 models support two modes:
|
||||
|
||||
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|
||||
|------|-------------|---------------------|------------------|
|
||||
| Quality | 0.7 | 0.5 | 16 |
|
||||
| Speed | 0.5 | `None` | 16 |
|
||||
|
||||
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
|
||||
|
||||
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
|
||||
|
||||
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
|
||||
|
||||
## LLaDA2Pipeline
|
||||
[[autodoc]] LLaDA2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LLaDA2PipelineOutput
|
||||
[[autodoc]] pipelines.LLaDA2PipelineOutput
|
||||
@@ -63,7 +63,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [Latte](latte) | text2image |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [LLaDA2](llada2) | text2text |
|
||||
| [Lumina-T2X](lumina) | text2image |
|
||||
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BlockRefinementScheduler
|
||||
|
||||
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
|
||||
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
|
||||
token with high confidence.
|
||||
|
||||
This scheduler is used by [`LLaDA2Pipeline`].
|
||||
|
||||
## BlockRefinementScheduler
|
||||
[[autodoc]] BlockRefinementScheduler
|
||||
|
||||
## BlockRefinementSchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput
|
||||
@@ -248,24 +248,6 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
|
||||
|
||||
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
|
||||
|
||||
## Kernels
|
||||
|
||||
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
|
||||
|
||||
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
|
||||
|
||||
> [!TIP]
|
||||
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
|
||||
|
||||
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
## Dynamic quantization
|
||||
|
||||
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# Discrete Token Diffusion (Experimental)
|
||||
|
||||
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
|
||||
|
||||
## LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
|
||||
|
||||
### Train
|
||||
|
||||
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name wikitext \
|
||||
--dataset_config_name wikitext-2-raw-v1 \
|
||||
--text_column text \
|
||||
--output_dir llada2-output \
|
||||
--max_train_steps 1000 \
|
||||
--prompt_length 32 \
|
||||
--block_length 32 \
|
||||
--lambda_conf 2.0 \
|
||||
--conf_temperature 0.5
|
||||
```
|
||||
|
||||
If you don't want to download a dataset, you can use random-token data:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--output_dir llada2-output \
|
||||
--use_dummy_data \
|
||||
--num_dummy_samples 2048
|
||||
```
|
||||
|
||||
### Sample
|
||||
|
||||
```bash
|
||||
python examples/discrete_diffusion/sample_llada2.py \
|
||||
--model_id inclusionAI/LLaDA2.1-mini \
|
||||
--prompt "Write a short poem about the ocean." \
|
||||
--gen_length 256 \
|
||||
--num_inference_steps 32 \
|
||||
--threshold 0.7 \
|
||||
--editing_threshold 0.5 \
|
||||
--max_post_steps 16 \
|
||||
--use_chat_template \
|
||||
--add_generation_prompt
|
||||
```
|
||||
@@ -1,263 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Sample script for LLaDA2-style discrete diffusion text generation.
|
||||
|
||||
This script demonstrates how to use the LLaDA2Pipeline for text generation
|
||||
using block-wise iterative refinement.
|
||||
|
||||
Example usage:
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="inclusionAI/LLaDA2.0-mini",
|
||||
help="HuggingFace model ID or path to local model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="Why does Camus think that Sisyphus is happy?",
|
||||
help="Text prompt to generate from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of tokens to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_length",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Size of each generation block.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_inference_steps",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of refinement steps per block.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Sampling temperature (0.0 for greedy).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Nucleus sampling probability threshold.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Confidence threshold for committing tokens.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--editing_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_post_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling_method",
|
||||
type=str,
|
||||
default="multinomial",
|
||||
choices=["auto", "greedy", "multinomial"],
|
||||
help="Sampling method for block refinement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos_early_stop",
|
||||
action="store_true",
|
||||
help="Stop generation early when EOS token is generated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_chat_template",
|
||||
action="store_true",
|
||||
help="Use the tokenizer chat template for the prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_generation_prompt",
|
||||
action="store_true",
|
||||
help="Add the generation prompt when using the chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to run inference on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["float32", "float16", "bfloat16"],
|
||||
help="Model dtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Random seed for reproducibility.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["group", "sequential"],
|
||||
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse dtype
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
torch_dtype = dtype_map[args.dtype]
|
||||
|
||||
print(f"Loading model: {args.model_id}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
|
||||
|
||||
# Load model with appropriate memory settings based on offload strategy
|
||||
if args.offload == "group":
|
||||
# For group offloading, load to CPU first then apply hooks
|
||||
print("Using group offloading for memory efficiency...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
# Apply group offloading with CUDA streams for better performance
|
||||
onload_device = torch.device(args.device)
|
||||
offload_device = torch.device("cpu")
|
||||
apply_group_offloading(
|
||||
model,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
)
|
||||
elif args.offload == "sequential":
|
||||
# For sequential offloading, load to CPU first
|
||||
print("Using sequential CPU offloading (slower but lower memory)...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
# Sequential offloading will be applied via pipeline
|
||||
else:
|
||||
# Default: use device_map="auto" for automatic memory management
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Create pipeline
|
||||
scheduler = BlockRefinementScheduler()
|
||||
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
|
||||
# Apply sequential CPU offload if requested
|
||||
if args.offload == "sequential":
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
# Set up generator for reproducibility
|
||||
generator = None
|
||||
if args.seed is not None:
|
||||
generator = torch.Generator(device=args.device).manual_seed(args.seed)
|
||||
|
||||
print(f"\nPrompt: {args.prompt}")
|
||||
print(
|
||||
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
|
||||
)
|
||||
print("-" * 50)
|
||||
|
||||
# Generate
|
||||
output = pipe(
|
||||
prompt=args.prompt,
|
||||
use_chat_template=args.use_chat_template,
|
||||
add_generation_prompt=args.add_generation_prompt,
|
||||
gen_length=args.gen_length,
|
||||
block_length=args.block_length,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
threshold=args.threshold,
|
||||
editing_threshold=args.editing_threshold,
|
||||
max_post_steps=args.max_post_steps,
|
||||
sampling_method=args.sampling_method,
|
||||
eos_early_stop=args.eos_early_stop,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
print("\nGenerated text:")
|
||||
print(output.texts[0])
|
||||
|
||||
print(f"\nGenerated {output.sequences.shape[1]} tokens")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,321 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
from diffusers.training_utils import compute_confidence_aware_loss
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
model_name_or_path: str
|
||||
dataset_name: str
|
||||
dataset_config_name: Optional[str]
|
||||
text_column: str
|
||||
cache_dir: Optional[str]
|
||||
use_dummy_data: bool
|
||||
num_dummy_samples: int
|
||||
|
||||
output_dir: str
|
||||
seed: int
|
||||
max_train_steps: int
|
||||
checkpointing_steps: int
|
||||
logging_steps: int
|
||||
|
||||
per_device_train_batch_size: int
|
||||
gradient_accumulation_steps: int
|
||||
learning_rate: float
|
||||
weight_decay: float
|
||||
lr_scheduler: str
|
||||
lr_warmup_steps: int
|
||||
|
||||
max_length: int
|
||||
prompt_length: int
|
||||
block_length: int
|
||||
|
||||
lambda_conf: float
|
||||
conf_temperature: float
|
||||
|
||||
|
||||
def parse_args() -> TrainConfig:
|
||||
parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
|
||||
|
||||
parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
|
||||
parser.add_argument("--dataset_name", type=str, default="wikitext")
|
||||
parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
|
||||
parser.add_argument("--text_column", type=str, default="text")
|
||||
parser.add_argument("--cache_dir", type=str, default=None)
|
||||
parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
|
||||
parser.add_argument("--num_dummy_samples", type=int, default=2048)
|
||||
|
||||
parser.add_argument("--output_dir", type=str, default="block-refinement-output")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--max_train_steps", type=int, default=1000)
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=500)
|
||||
parser.add_argument("--logging_steps", type=int, default=50)
|
||||
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
|
||||
)
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=100)
|
||||
|
||||
parser.add_argument("--max_length", type=int, default=256)
|
||||
parser.add_argument("--prompt_length", type=int, default=32)
|
||||
parser.add_argument("--block_length", type=int, default=32)
|
||||
|
||||
parser.add_argument("--lambda_conf", type=float, default=2.0)
|
||||
parser.add_argument("--conf_temperature", type=float, default=0.5)
|
||||
|
||||
args = parser.parse_args()
|
||||
return TrainConfig(**vars(args))
|
||||
|
||||
|
||||
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
|
||||
texts = examples[text_column]
|
||||
texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
|
||||
return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
|
||||
|
||||
|
||||
class RandomTokenDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
|
||||
self.num_samples = int(num_samples)
|
||||
self.seq_len = int(seq_len)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.pad_token_id = int(pad_token_id)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
del idx
|
||||
input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
def main():
|
||||
cfg = parse_args()
|
||||
if cfg.prompt_length >= cfg.max_length:
|
||||
raise ValueError("`prompt_length` must be < `max_length`.")
|
||||
if cfg.block_length <= 0:
|
||||
raise ValueError("`block_length` must be > 0.")
|
||||
|
||||
project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
project_config=project_config,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
set_seed(cfg.seed)
|
||||
logger.info("Training configuration: %s", asdict(cfg))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
if tokenizer.mask_token_id is None:
|
||||
tokenizer.add_special_tokens({"mask_token": "[MASK]"})
|
||||
|
||||
load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
if load_dtype == torch.float32:
|
||||
model.to(dtype=torch.float32)
|
||||
|
||||
mask_token_id = int(tokenizer.mask_token_id)
|
||||
|
||||
if cfg.use_dummy_data:
|
||||
dataset = RandomTokenDataset(
|
||||
num_samples=cfg.num_dummy_samples,
|
||||
seq_len=cfg.max_length,
|
||||
vocab_size=len(tokenizer),
|
||||
pad_token_id=int(tokenizer.pad_token_id),
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=cfg.per_device_train_batch_size,
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
|
||||
|
||||
with accelerator.main_process_first():
|
||||
tokenized = raw_datasets["train"].map(
|
||||
lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
|
||||
batched=True,
|
||||
remove_columns=raw_datasets["train"].column_names,
|
||||
desc="Tokenizing",
|
||||
)
|
||||
|
||||
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
|
||||
train_dataloader = DataLoader(
|
||||
tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
name=cfg.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.lr_warmup_steps,
|
||||
num_training_steps=cfg.max_train_steps,
|
||||
)
|
||||
|
||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)
|
||||
|
||||
global_step = 0
|
||||
model.train()
|
||||
|
||||
for _epoch in range(num_train_epochs):
|
||||
for batch in train_dataloader:
|
||||
with accelerator.accumulate(model):
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))
|
||||
|
||||
gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
|
||||
noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
prompt_length=cfg.prompt_length,
|
||||
block_length=cfg.block_length,
|
||||
mask_token_id=mask_token_id,
|
||||
generator=gen,
|
||||
)
|
||||
|
||||
position_ids = (
|
||||
torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
|
||||
)
|
||||
|
||||
logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
|
||||
logits_rev = model(
|
||||
input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
|
||||
).logits
|
||||
|
||||
logits = logits.clone()
|
||||
logits[..., mask_token_id] = torch.finfo(logits.dtype).min
|
||||
logits_rev = logits_rev.clone()
|
||||
logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min
|
||||
|
||||
valid = attention_mask.to(dtype=torch.bool)
|
||||
masked = masked & valid
|
||||
masked_rev = masked_rev & valid
|
||||
|
||||
labels = input_ids.clone()
|
||||
labels[~masked] = -100
|
||||
labels_rev = input_ids.clone()
|
||||
labels_rev[~masked_rev] = -100
|
||||
|
||||
weights = masked.to(dtype=logits.dtype)
|
||||
weights_rev = masked_rev.to(dtype=logits.dtype)
|
||||
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits,
|
||||
labels,
|
||||
lambda_conf=cfg.lambda_conf,
|
||||
temperature=cfg.conf_temperature,
|
||||
per_token_weights=weights,
|
||||
)
|
||||
loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
|
||||
logits_rev,
|
||||
labels_rev,
|
||||
lambda_conf=cfg.lambda_conf,
|
||||
temperature=cfg.conf_temperature,
|
||||
per_token_weights=weights_rev,
|
||||
)
|
||||
|
||||
total_loss = loss + loss_rev
|
||||
accelerator.backward(total_loss)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
global_step += 1
|
||||
|
||||
if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
|
||||
logger.info(
|
||||
"step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
|
||||
global_step,
|
||||
total_loss.item(),
|
||||
(loss_sft + loss_sft_rev).item(),
|
||||
(loss_conf + loss_conf_rev).item(),
|
||||
lr_scheduler.get_last_lr()[0],
|
||||
)
|
||||
print(
|
||||
f"step={global_step} loss={total_loss.item():.4f} "
|
||||
f"sft={(loss_sft + loss_sft_rev).item():.4f} "
|
||||
f"conf={(loss_conf + loss_conf_rev).item():.4f} "
|
||||
f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
|
||||
)
|
||||
|
||||
if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
|
||||
if global_step >= cfg.max_train_steps:
|
||||
break
|
||||
|
||||
if global_step >= cfg.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
final_dir = os.path.join(cfg.output_dir, "final")
|
||||
os.makedirs(final_dir, exist_ok=True)
|
||||
accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
|
||||
tokenizer.save_pretrained(final_dir)
|
||||
|
||||
logger.info("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -344,8 +344,6 @@ else:
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"AmusedScheduler",
|
||||
"BlockRefinementScheduler",
|
||||
"BlockRefinementSchedulerOutput",
|
||||
"CMStochasticIterativeScheduler",
|
||||
"CogVideoXDDIMScheduler",
|
||||
"CogVideoXDPMScheduler",
|
||||
@@ -582,8 +580,6 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LLaDA2Pipeline",
|
||||
"LLaDA2PipelineOutput",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
@@ -1128,8 +1124,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .quantizers import DiffusersQuantizer
|
||||
from .schedulers import (
|
||||
AmusedScheduler,
|
||||
BlockRefinementScheduler,
|
||||
BlockRefinementSchedulerOutput,
|
||||
CMStochasticIterativeScheduler,
|
||||
CogVideoXDDIMScheduler,
|
||||
CogVideoXDPMScheduler,
|
||||
@@ -1345,8 +1339,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LLaDA2Pipeline,
|
||||
LLaDA2PipelineOutput,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
|
||||
@@ -285,7 +285,6 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
@@ -729,7 +728,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,491 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...schedulers import BlockRefinementScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
>>> model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(
|
||||
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
... )
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
>>> scheduler = BlockRefinementScheduler()
|
||||
|
||||
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
|
||||
>>> print(output.texts[0])
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDA2PipelineOutput(BaseOutput):
|
||||
sequences: torch.LongTensor
|
||||
texts: list[str] | None = None
|
||||
|
||||
|
||||
class LLaDA2Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
|
||||
|
||||
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
|
||||
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
|
||||
|
||||
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
|
||||
vocab_size]`.
|
||||
"""
|
||||
|
||||
model: Any
|
||||
scheduler: BlockRefinementScheduler
|
||||
tokenizer: Any
|
||||
|
||||
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
scheduler: BlockRefinementScheduler,
|
||||
tokenizer: Any | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
|
||||
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
# --- Prompt encoding ---
|
||||
|
||||
def _prepare_input_ids(
|
||||
self,
|
||||
*,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
use_chat_template: bool,
|
||||
add_generation_prompt: bool,
|
||||
chat_template_kwargs: dict[str, Any] | None,
|
||||
) -> torch.LongTensor:
|
||||
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
if input_ids.ndim != 2:
|
||||
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
return input_ids
|
||||
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
if messages is not None and prompt is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if messages is None and prompt is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
|
||||
chat_template_kwargs = chat_template_kwargs or {}
|
||||
|
||||
if messages is not None:
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
|
||||
if isinstance(prompt, list):
|
||||
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
|
||||
return encoded["input_ids"]
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
gen_length: int,
|
||||
block_length: int,
|
||||
num_inference_steps: int,
|
||||
minimal_topk: int,
|
||||
threshold: float,
|
||||
sampling_method: str,
|
||||
output_type: str,
|
||||
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None,
|
||||
):
|
||||
# Input source validation
|
||||
if prompt is None and messages is None and input_ids is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
if prompt is not None and messages is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim not in (1, 2):
|
||||
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
if prompt is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
if messages is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
# Generation parameter validation
|
||||
if gen_length <= 0:
|
||||
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
|
||||
if block_length <= 0:
|
||||
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
if minimal_topk <= 0:
|
||||
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
|
||||
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
|
||||
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
|
||||
if sampling_method not in {"auto", "greedy", "multinomial"}:
|
||||
raise ValueError(
|
||||
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
|
||||
)
|
||||
if output_type not in {"seq", "text"}:
|
||||
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
|
||||
|
||||
# Callback validation
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
||||
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] | None = None,
|
||||
messages: list[dict[str, str]] | None = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
use_chat_template: bool = True,
|
||||
add_generation_prompt: bool = True,
|
||||
gen_length: int = 2048,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
sampling_method: str = "multinomial",
|
||||
threshold: float = 0.7,
|
||||
editing_threshold: float | None = 0.5,
|
||||
max_post_steps: int = 16,
|
||||
minimal_topk: int = 1,
|
||||
eos_early_stop: bool = True,
|
||||
eos_token_id: int | None = None,
|
||||
mask_token_id: int | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
output_type: str = "text",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Callable[[int, int, dict], None]
|
||||
| PipelineCallback
|
||||
| MultiPipelineCallbacks
|
||||
| None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None = None,
|
||||
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
|
||||
"""
|
||||
Generate text with block-wise refinement.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
|
||||
available, the prompt is wrapped in a chat message before tokenization.
|
||||
messages (`List[Dict[str, str]]`, *optional*):
|
||||
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
|
||||
when provided. Requires a tokenizer with `apply_chat_template`.
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
|
||||
use_chat_template (`bool`, defaults to `True`):
|
||||
Whether to wrap the prompt in a chat template.
|
||||
add_generation_prompt (`bool`, defaults to `True`):
|
||||
Whether to add the generation prompt when using chat templates.
|
||||
gen_length (`int`):
|
||||
Number of tokens to generate.
|
||||
block_length (`int`):
|
||||
Block size for refinement.
|
||||
num_inference_steps (`int`):
|
||||
Number of refinement steps per block.
|
||||
temperature (`float`):
|
||||
Sampling temperature.
|
||||
top_p (`float`, *optional*):
|
||||
Nucleus sampling cutoff.
|
||||
top_k (`int`, *optional*):
|
||||
Top-k sampling cutoff.
|
||||
sampling_method (`str`):
|
||||
Sampling method (`auto`, `greedy`, `multinomial`).
|
||||
threshold (`float`):
|
||||
Confidence threshold for committing tokens.
|
||||
editing_threshold (`float`, *optional*):
|
||||
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
|
||||
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
|
||||
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
|
||||
negative value to disable editing. Defaults to `0.5`.
|
||||
max_post_steps (`int`):
|
||||
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
|
||||
used when `editing_threshold` is enabled. Defaults to `16`.
|
||||
minimal_topk (`int`):
|
||||
Minimum number of tokens to commit per step.
|
||||
eos_early_stop (`bool`):
|
||||
Whether to stop after committing EOS in a block.
|
||||
eos_token_id (`int`, *optional*):
|
||||
EOS token ID to use for early stopping.
|
||||
mask_token_id (`int`, *optional*):
|
||||
Mask token ID to use for the template.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for sampling.
|
||||
output_type (`str`, defaults to `"text"`):
|
||||
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
|
||||
token ID sequences only.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
|
||||
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
|
||||
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
|
||||
timestep: int, callback_kwargs: Dict)`.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
|
||||
`confidence`, `active_block`.
|
||||
|
||||
Examples:
|
||||
"""
|
||||
# 1. Check inputs early
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["block_x"]
|
||||
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
gen_length=gen_length,
|
||||
block_length=block_length,
|
||||
num_inference_steps=num_inference_steps,
|
||||
minimal_topk=minimal_topk,
|
||||
threshold=threshold,
|
||||
sampling_method=sampling_method,
|
||||
output_type=output_type,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Prepare input IDs from prompt/messages/input_ids
|
||||
prompt_ids = self._prepare_input_ids(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
use_chat_template=use_chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if prompt_ids.ndim == 1:
|
||||
prompt_ids = prompt_ids.unsqueeze(0)
|
||||
prompt_ids = prompt_ids.to(device=device)
|
||||
batch_size, prompt_length = prompt_ids.shape
|
||||
|
||||
if eos_token_id is None:
|
||||
eos_token_id = self.eos_token_id
|
||||
if mask_token_id is None:
|
||||
mask_token_id = self.mask_token_id
|
||||
if mask_token_id is None:
|
||||
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
|
||||
|
||||
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# 3. Build attention mask and position IDs
|
||||
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
|
||||
total_length = num_blocks * block_length
|
||||
|
||||
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
|
||||
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
|
||||
|
||||
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# 4. Prepare latents (fully masked sequence)
|
||||
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
|
||||
if prompt_length > 0:
|
||||
x[:, :prompt_length] = prompt_ids
|
||||
|
||||
prefill_blocks = prompt_length // block_length
|
||||
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
|
||||
|
||||
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
|
||||
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
|
||||
global_step = 0
|
||||
|
||||
# 5. Block-wise refinement loop
|
||||
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
|
||||
block_progress_bar_config["position"] = 0
|
||||
block_progress_bar_config["desc"] = "Blocks"
|
||||
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
|
||||
current_window_end = (num_block + 1) * block_length
|
||||
block_x = x[:, :current_window_end]
|
||||
block_attn_mask = attn_mask[:, :current_window_end]
|
||||
block_position_ids = position_ids[:, :current_window_end]
|
||||
|
||||
# Identify which positions in the block are prompt (non-editable).
|
||||
block_start_pos = num_block * block_length
|
||||
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
|
||||
if block_start_pos < prompt_length:
|
||||
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
|
||||
prompt_mask_in_block[:prompt_end_in_block] = True
|
||||
|
||||
post_steps = 0
|
||||
step_idx = 0
|
||||
should_continue = True
|
||||
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
|
||||
progress_bar = self.progress_bar(total=num_inference_steps)
|
||||
|
||||
while should_continue:
|
||||
block_tokens = block_x[:, -block_length:]
|
||||
masks_remaining = (block_tokens == mask_token_id).any()
|
||||
|
||||
if not masks_remaining:
|
||||
post_steps += 1
|
||||
|
||||
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
|
||||
block_logits = logits[:, -block_length:, :]
|
||||
|
||||
scheduler_output = self.scheduler.step(
|
||||
model_output=block_logits,
|
||||
timestep=step_idx,
|
||||
sample=block_tokens,
|
||||
mask_token_id=mask_token_id,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
sampling_method=sampling_method,
|
||||
threshold=threshold,
|
||||
editing_threshold=editing_threshold,
|
||||
minimal_topk=minimal_topk,
|
||||
prompt_mask=prompt_mask_in_block,
|
||||
generator=generator,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
transfer_index = scheduler_output.transfer_index
|
||||
editing_transfer_index = scheduler_output.editing_transfer_index
|
||||
final_transfer = transfer_index | editing_transfer_index
|
||||
|
||||
if final_transfer.any():
|
||||
block_x[:, -block_length:] = scheduler_output.prev_sample
|
||||
|
||||
if eos_early_stop and eos_token_id is not None:
|
||||
finished = self.scheduler.check_eos_finished(
|
||||
cur_x=block_x,
|
||||
sampled_tokens=scheduler_output.sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_token_id,
|
||||
mask_token_id=mask_token_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
|
||||
block_x = callback_outputs.pop("block_x", block_x)
|
||||
|
||||
global_step += 1
|
||||
if masks_remaining:
|
||||
step_idx += 1
|
||||
progress_bar.update(1)
|
||||
|
||||
should_continue = self.scheduler.check_block_should_continue(
|
||||
step_idx=step_idx,
|
||||
masks_remaining=masks_remaining,
|
||||
editing_enabled=editing_enabled,
|
||||
editing_transfer_index=editing_transfer_index,
|
||||
post_steps=post_steps,
|
||||
max_post_steps=max_post_steps,
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
progress_bar.close()
|
||||
x[:, :current_window_end] = block_x
|
||||
if eos_early_stop and finished.all():
|
||||
break
|
||||
|
||||
# 6. Post-process output
|
||||
generated = x[:, : prompt_length + gen_length]
|
||||
sequences = generated[:, prompt_length:]
|
||||
if eos_token_id is not None and batch_size == 1:
|
||||
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
||||
if len(eos_positions) > 0:
|
||||
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
|
||||
|
||||
texts = None
|
||||
if output_type == "text" and self.tokenizer is not None:
|
||||
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||
|
||||
if not return_dict:
|
||||
return sequences.to(device=device), texts
|
||||
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
|
||||
|
||||
|
||||
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
@@ -40,7 +40,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
||||
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
|
||||
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
||||
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
||||
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
||||
@@ -146,7 +145,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
||||
from .scheduling_amused import AmusedScheduler
|
||||
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
|
||||
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockRefinementSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for block refinement scheduling.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Updated block tokens after the current refinement step.
|
||||
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were committed (mask-filling).
|
||||
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were edited (non-mask replacement).
|
||||
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Sampled token IDs from the model logits.
|
||||
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
|
||||
Probabilities of the sampled tokens.
|
||||
"""
|
||||
|
||||
prev_sample: torch.LongTensor
|
||||
transfer_index: torch.BoolTensor
|
||||
editing_transfer_index: torch.BoolTensor
|
||||
sampled_tokens: torch.LongTensor
|
||||
sampled_probs: torch.Tensor
|
||||
|
||||
|
||||
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Scheduler for block-wise iterative refinement (commit-by-confidence).
|
||||
|
||||
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
|
||||
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
|
||||
the number of refinement steps.
|
||||
|
||||
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
|
||||
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
threshold: float = 0.95,
|
||||
editing_threshold: float | None = None,
|
||||
minimal_topk: int = 1,
|
||||
):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
|
||||
self._transfer_schedule: torch.LongTensor | None = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
|
||||
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
|
||||
device=device if device is not None else "cpu"
|
||||
)
|
||||
|
||||
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
|
||||
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
|
||||
if num_inference_steps <= 0:
|
||||
return torch.zeros((0,), dtype=torch.long)
|
||||
base = block_length // num_inference_steps
|
||||
remainder = block_length % num_inference_steps
|
||||
out = torch.full((num_inference_steps,), base, dtype=torch.long)
|
||||
out[:remainder] += 1
|
||||
return out
|
||||
|
||||
# --- SAR sampling utilities ---
|
||||
|
||||
@staticmethod
|
||||
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
|
||||
"""Nucleus (top-p) logit filtering."""
|
||||
if top_p is None or top_p >= 1.0:
|
||||
return logits
|
||||
if not (0.0 < top_p <= 1.0):
|
||||
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
||||
cumulative_probs = sorted_probs.cumsum(dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > float(top_p)
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
|
||||
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
|
||||
"""Top-k logit filtering."""
|
||||
if top_k is None or top_k <= 0:
|
||||
return logits
|
||||
if top_k >= logits.shape[-1]:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k=top_k, dim=-1)
|
||||
min_keep = values[..., -1, None]
|
||||
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
|
||||
|
||||
@staticmethod
|
||||
def _sample_from_logits(
|
||||
logits: torch.Tensor,
|
||||
*,
|
||||
temperature: float,
|
||||
top_k: int | None,
|
||||
top_p: float | None,
|
||||
generator: torch.Generator | None,
|
||||
use_multinomial: bool,
|
||||
) -> tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
|
||||
if temperature < 0:
|
||||
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, vocab_size)
|
||||
|
||||
if temperature == 0.0 or not use_multinomial:
|
||||
probs = torch.softmax(flat_logits.float(), dim=-1)
|
||||
token = flat_logits.argmax(dim=-1, keepdim=True)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
scaled = flat_logits
|
||||
if temperature != 1.0:
|
||||
scaled = flat_logits / temperature
|
||||
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
|
||||
|
||||
probs = torch.softmax(filtered.float(), dim=-1)
|
||||
token = torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int | torch.Tensor,
|
||||
sample: torch.LongTensor,
|
||||
*,
|
||||
mask_token_id: int,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
sampling_method: str = "auto",
|
||||
threshold: float | None = None,
|
||||
editing_threshold: float | None = None,
|
||||
minimal_topk: int | None = None,
|
||||
prompt_mask: torch.BoolTensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> (
|
||||
BlockRefinementSchedulerOutput
|
||||
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
|
||||
ones.
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
|
||||
Raw logits from the model for the current block.
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
Current step index within the block's refinement schedule.
|
||||
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Current block token IDs (contains mask tokens for uncommitted positions).
|
||||
mask_token_id (`int`):
|
||||
Token ID used for masked positions.
|
||||
temperature (`float`):
|
||||
Sampling temperature.
|
||||
top_p (`float`, *optional*):
|
||||
Nucleus sampling cutoff.
|
||||
top_k (`int`, *optional*):
|
||||
Top-k sampling cutoff.
|
||||
sampling_method (`str`):
|
||||
Sampling method (`auto`, `greedy`, `multinomial`).
|
||||
threshold (`float`, *optional*):
|
||||
Confidence threshold for committing tokens. Defaults to config value.
|
||||
editing_threshold (`float`, *optional*):
|
||||
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
|
||||
config value.
|
||||
minimal_topk (`int`, *optional*):
|
||||
Minimum tokens to commit per step. Defaults to config value.
|
||||
prompt_mask (`torch.BoolTensor`, *optional*):
|
||||
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for sampling.
|
||||
return_dict (`bool`):
|
||||
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = float(self.config.threshold)
|
||||
if editing_threshold is None:
|
||||
editing_threshold = self.config.editing_threshold
|
||||
if minimal_topk is None:
|
||||
minimal_topk = self.config.minimal_topk
|
||||
|
||||
# Sample from logits
|
||||
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
|
||||
sampled_tokens, sampled_probs = self._sample_from_logits(
|
||||
model_output,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
generator=generator,
|
||||
use_multinomial=use_multinomial,
|
||||
)
|
||||
|
||||
batch_size, block_length = sample.shape
|
||||
active_block = sample == mask_token_id
|
||||
masks_remaining = active_block.any()
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
step_index = int(timestep.item())
|
||||
else:
|
||||
step_index = int(timestep)
|
||||
|
||||
# --- Mask-filling transfer ---
|
||||
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
|
||||
if masks_remaining and self._transfer_schedule is not None:
|
||||
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
|
||||
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
|
||||
|
||||
confidence = torch.where(
|
||||
active_block,
|
||||
sampled_probs.to(dtype=torch.float32),
|
||||
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
|
||||
)
|
||||
|
||||
for b in range(batch_size):
|
||||
high_conf = confidence[b] > threshold
|
||||
if high_conf.sum().item() >= num_to_transfer:
|
||||
transfer_index[b] = high_conf
|
||||
else:
|
||||
k = min(num_to_transfer, int(active_block[b].sum().item()))
|
||||
if k > 0:
|
||||
_, idx = torch.topk(confidence[b], k=k)
|
||||
transfer_index[b, idx] = True
|
||||
|
||||
# --- Editing transfer (non-mask, non-prompt positions) ---
|
||||
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
|
||||
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
|
||||
if editing_enabled:
|
||||
if prompt_mask is None:
|
||||
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
|
||||
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
|
||||
editing_conf = torch.where(
|
||||
editable,
|
||||
sampled_probs.to(dtype=torch.float32),
|
||||
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
|
||||
)
|
||||
high_conf_edit = editing_conf > float(editing_threshold)
|
||||
token_changed = sampled_tokens != sample
|
||||
editing_transfer_index = high_conf_edit & token_changed & editable
|
||||
|
||||
# Apply transfers
|
||||
final_transfer = transfer_index | editing_transfer_index
|
||||
prev_sample = sample.clone()
|
||||
if final_transfer.any():
|
||||
prev_sample[final_transfer] = sampled_tokens[final_transfer]
|
||||
|
||||
if not return_dict:
|
||||
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
|
||||
return BlockRefinementSchedulerOutput(
|
||||
prev_sample=prev_sample,
|
||||
transfer_index=transfer_index,
|
||||
editing_transfer_index=editing_transfer_index,
|
||||
sampled_tokens=sampled_tokens,
|
||||
sampled_probs=sampled_probs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_eos_finished(
|
||||
cur_x: torch.LongTensor,
|
||||
sampled_tokens: torch.LongTensor,
|
||||
final_transfer: torch.BoolTensor,
|
||||
finished: torch.BoolTensor,
|
||||
eos_token_id: int,
|
||||
mask_token_id: int,
|
||||
prompt_length: int,
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Update per-batch finished flags when EOS tokens are committed.
|
||||
|
||||
Args:
|
||||
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Current full sequence including all blocks up to the current window.
|
||||
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Tokens sampled by the scheduler in this step.
|
||||
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Combined mask of committed and edited positions.
|
||||
finished (`torch.BoolTensor` of shape `(batch_size,)`):
|
||||
Current per-batch finished flags.
|
||||
eos_token_id (`int`):
|
||||
EOS token ID.
|
||||
mask_token_id (`int`):
|
||||
Mask token ID.
|
||||
prompt_length (`int`):
|
||||
Number of prompt tokens at the start of the sequence.
|
||||
|
||||
Returns:
|
||||
`torch.BoolTensor`: Updated finished flags.
|
||||
"""
|
||||
batch_size = cur_x.shape[0]
|
||||
for b in range(batch_size):
|
||||
if finished[b]:
|
||||
continue
|
||||
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
|
||||
if not eos_in_commits:
|
||||
continue
|
||||
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
|
||||
if len(eos_pos[0]) == 0:
|
||||
continue
|
||||
eos_pos = int(eos_pos[0][0].item())
|
||||
if prompt_length >= eos_pos:
|
||||
continue
|
||||
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
|
||||
finished[b] = True
|
||||
return finished
|
||||
|
||||
def check_block_should_continue(
|
||||
self,
|
||||
step_idx: int,
|
||||
masks_remaining: bool,
|
||||
editing_enabled: bool,
|
||||
editing_transfer_index: torch.BoolTensor,
|
||||
post_steps: int,
|
||||
max_post_steps: int,
|
||||
finished: torch.BoolTensor,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine whether the inner refinement loop should continue for the current block.
|
||||
|
||||
Args:
|
||||
step_idx (`int`):
|
||||
Current refinement step index within this block.
|
||||
masks_remaining (`bool`):
|
||||
Whether any mask tokens remain in the block.
|
||||
editing_enabled (`bool`):
|
||||
Whether editing mode is active.
|
||||
editing_transfer_index (`torch.BoolTensor`):
|
||||
Which tokens were edited in this step.
|
||||
post_steps (`int`):
|
||||
Number of post-mask editing steps taken so far.
|
||||
max_post_steps (`int`):
|
||||
Maximum allowed post-mask editing steps.
|
||||
finished (`torch.BoolTensor`):
|
||||
Per-batch finished flags (from EOS detection).
|
||||
|
||||
Returns:
|
||||
`bool`: `True` if refinement should continue, `False` to break.
|
||||
"""
|
||||
if finished.all():
|
||||
return False
|
||||
if not masks_remaining and not editing_enabled:
|
||||
return False
|
||||
if not masks_remaining and not editing_transfer_index.any():
|
||||
return False
|
||||
if masks_remaining and step_idx >= self.num_inference_steps:
|
||||
return False
|
||||
if not masks_remaining and post_steps > max_post_steps:
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.LongTensor,
|
||||
attention_mask: torch.LongTensor,
|
||||
*,
|
||||
prompt_length: int,
|
||||
block_length: int,
|
||||
mask_token_id: int,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
|
||||
"""
|
||||
Apply the forward (noising) process for semi-autoregressive block masking.
|
||||
|
||||
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
|
||||
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
|
||||
one are the unmasked positions in the other.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Clean token IDs.
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Padding mask (1 for valid, 0 for padding).
|
||||
prompt_length (`int`):
|
||||
Number of leading prompt tokens to keep unmasked.
|
||||
block_length (`int`):
|
||||
Block size for masking.
|
||||
mask_token_id (`int`):
|
||||
Token ID to use for masked positions.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for reproducibility.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
|
||||
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
|
||||
corresponding boolean masks.
|
||||
"""
|
||||
batch_size, seq_len = original_samples.shape
|
||||
device = original_samples.device
|
||||
|
||||
noisy = original_samples.clone()
|
||||
noisy_rev = original_samples.clone()
|
||||
masked = torch.zeros_like(original_samples, dtype=torch.bool)
|
||||
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
|
||||
|
||||
valid = attention_mask.to(dtype=torch.bool)
|
||||
for block_start in range(prompt_length, seq_len, block_length):
|
||||
block_end = min(seq_len, block_start + block_length)
|
||||
seg_len = block_end - block_start
|
||||
if seg_len <= 0:
|
||||
continue
|
||||
|
||||
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
|
||||
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
|
||||
seg = seg & valid[:, block_start:block_end]
|
||||
seg_rev = (~seg) & valid[:, block_start:block_end]
|
||||
|
||||
masked[:, block_start:block_end] = seg
|
||||
masked_rev[:, block_start:block_end] = seg_rev
|
||||
|
||||
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
|
||||
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
|
||||
return noisy, noisy_rev, masked, masked_rev
|
||||
|
||||
|
||||
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
@@ -11,7 +11,6 @@ from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
@@ -110,92 +109,6 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
return snr
|
||||
|
||||
|
||||
def compute_confidence_aware_loss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
*,
|
||||
lambda_conf: float = 0.0,
|
||||
temperature: float = 1.0,
|
||||
per_token_weights: torch.Tensor | None = None,
|
||||
ignore_index: int = -100,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes a confidence-aware training loss for token classification-style heads.
|
||||
|
||||
This loss combines:
|
||||
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
|
||||
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
|
||||
|
||||
Args:
|
||||
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
|
||||
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
|
||||
are excluded from both losses.
|
||||
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
|
||||
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
|
||||
sharpen the distribution and change the strength of the confidence gradients.
|
||||
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
|
||||
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
|
||||
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
|
||||
"""
|
||||
if logits.ndim < 2:
|
||||
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
|
||||
if labels.shape != logits.shape[:-1]:
|
||||
raise ValueError(
|
||||
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
|
||||
)
|
||||
if temperature <= 0:
|
||||
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
|
||||
|
||||
valid = labels.ne(ignore_index)
|
||||
if per_token_weights is None:
|
||||
weights = torch.ones_like(labels, dtype=logits.dtype)
|
||||
else:
|
||||
if per_token_weights.shape != labels.shape:
|
||||
raise ValueError(
|
||||
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
|
||||
)
|
||||
weights = per_token_weights.to(dtype=logits.dtype)
|
||||
|
||||
# Supervised CE (optionally weighted).
|
||||
vocab_size = logits.shape[-1]
|
||||
per_token_nll = F.cross_entropy(
|
||||
logits.reshape(-1, vocab_size),
|
||||
labels.reshape(-1),
|
||||
reduction="none",
|
||||
ignore_index=ignore_index,
|
||||
).reshape_as(labels)
|
||||
|
||||
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
|
||||
|
||||
# Confidence loss: penalize entropy only where prediction is already correct.
|
||||
if lambda_conf == 0.0:
|
||||
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
|
||||
return loss_sft, loss_sft, loss_conf
|
||||
|
||||
with torch.no_grad():
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = valid & pred.eq(labels)
|
||||
|
||||
scaled_logits = logits.float()
|
||||
if temperature != 1.0:
|
||||
scaled_logits = scaled_logits / float(temperature)
|
||||
|
||||
probs = torch.softmax(scaled_logits, dim=-1)
|
||||
eps = torch.finfo(probs.dtype).tiny
|
||||
log_probs = torch.log(probs.clamp_min(eps))
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
|
||||
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
|
||||
|
||||
loss = loss_sft + float(lambda_conf) * loss_conf
|
||||
return loss, loss_sft, loss_conf
|
||||
|
||||
|
||||
def resolve_interpolation_mode(interpolation_type: str):
|
||||
"""
|
||||
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
|
||||
|
||||
@@ -2518,36 +2518,6 @@ class AmusedScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlockRefinementScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CMStochasticIterativeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2222,36 +2222,6 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LLaDA2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LLaDA2PipelineOutput(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatImageEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -177,6 +177,11 @@ class QuantizationTesterMixin:
|
||||
model_quantized.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
model_dtype = next(model_quantized.parameters()).dtype
|
||||
inputs = {
|
||||
k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model_quantized(**inputs, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
@@ -930,6 +935,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
@pytest.mark.xfail(reason="dequantize is not implemented in torchao")
|
||||
def test_torchao_dequantize(self):
|
||||
"""Test that dequantize() works correctly."""
|
||||
self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
|
||||
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
|
||||
from diffusers.training_utils import set_seed
|
||||
|
||||
from ..testing_utils import slow
|
||||
|
||||
@@ -85,47 +85,3 @@ class TrainingTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
|
||||
|
||||
def test_confidence_aware_loss(self):
|
||||
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
|
||||
labels = torch.tensor([[0, 0]])
|
||||
weights = torch.tensor([[1.0, 2.0]])
|
||||
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=0.0, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss, loss_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
|
||||
|
||||
lambda_conf = 0.25
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
|
||||
)
|
||||
|
||||
# Manual expected values for the small 2-class case.
|
||||
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
|
||||
1, 2
|
||||
)
|
||||
expected_sft = (per_token_nll * weights).sum() / weights.sum()
|
||||
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = pred.eq(labels)
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
probs = log_probs.exp()
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
|
||||
weights * correct.to(weights.dtype)
|
||||
).sum().clamp_min(1)
|
||||
|
||||
expected = expected_sft + lambda_conf * expected_conf
|
||||
self.assertTrue(torch.allclose(loss_sft, expected_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss, expected))
|
||||
|
||||
# Temperature affects only the confidence term.
|
||||
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
|
||||
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
|
||||
class _DummyModelOutput:
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
|
||||
|
||||
class _DummyCausalLM(torch.nn.Module):
|
||||
def __init__(self, vocab_size: int):
|
||||
super().__init__()
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.register_buffer("_device_anchor", torch.empty(0))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device_anchor.device
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
|
||||
|
||||
# Make confidence vary with token position so top-k commits are deterministic.
|
||||
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
|
||||
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
|
||||
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
|
||||
return _DummyModelOutput(logits=logits)
|
||||
|
||||
|
||||
def _make_pipeline(tokenizer=None):
|
||||
model = _DummyCausalLM(vocab_size=32)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class LLaDA2PipelineTest(unittest.TestCase):
|
||||
def test_pipeline_runs(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
|
||||
out = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=24,
|
||||
block_length=8,
|
||||
num_inference_steps=8,
|
||||
temperature=0.0,
|
||||
threshold=2.0, # force top-k commits
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
eos_token_id=None,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertEqual(out.sequences.shape, (2, 24))
|
||||
self.assertFalse((out.sequences == 31).any().item())
|
||||
|
||||
def test_pipeline_return_tuple(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
|
||||
sequences, texts = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertEqual(sequences.shape, (1, 16))
|
||||
self.assertIsNone(texts)
|
||||
|
||||
def test_output_type_seq(self):
|
||||
"""output_type='seq' should return sequences but no texts."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertEqual(out.sequences.shape, (1, 16))
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_without_tokenizer(self):
|
||||
"""output_type='text' without a tokenizer should return texts=None."""
|
||||
pipe = _make_pipeline(tokenizer=None).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_with_tokenizer(self):
|
||||
"""output_type='text' with a tokenizer should return decoded texts."""
|
||||
tok = type(
|
||||
"Tok",
|
||||
(),
|
||||
{
|
||||
"eos_token_id": None,
|
||||
"mask_token_id": 31,
|
||||
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
|
||||
},
|
||||
)()
|
||||
pipe = _make_pipeline(tokenizer=tok).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNotNone(out.texts)
|
||||
self.assertEqual(len(out.texts), 1)
|
||||
self.assertTrue(out.texts[0].startswith("decoded_"))
|
||||
|
||||
def test_output_type_invalid_raises(self):
|
||||
"""Invalid output_type should raise ValueError."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
mask_token_id=31,
|
||||
output_type="invalid",
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_from_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertTrue(torch.equal(result, ids))
|
||||
|
||||
def test_prepare_input_ids_from_1d_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([1, 2, 3], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertEqual(result.shape, (1, 3))
|
||||
|
||||
def test_prepare_input_ids_no_tokenizer_raises(self):
|
||||
pipe = _make_pipeline(tokenizer=None)
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_neither_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1534,18 +1534,14 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||
|
||||
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
self.assertTrue(all(device == torch_device for device in model_devices))
|
||||
|
||||
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
@@ -1556,11 +1552,11 @@ class PipelineTesterMixin:
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
|
||||
@@ -1,470 +0,0 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
|
||||
|
||||
class BlockRefinementSchedulerTest(unittest.TestCase):
|
||||
def get_scheduler(self, **kwargs):
|
||||
config = {
|
||||
"block_length": 32,
|
||||
"num_inference_steps": 8,
|
||||
"threshold": 0.95,
|
||||
"editing_threshold": None,
|
||||
"minimal_topk": 1,
|
||||
}
|
||||
config.update(kwargs)
|
||||
return BlockRefinementScheduler(**config)
|
||||
|
||||
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Create logits where softmax of the target token has approximately the given probability."""
|
||||
batch_size, block_length = target_probs.shape
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
# Set token 0 as the "predicted" token with a logit proportional to desired probability
|
||||
for b in range(batch_size):
|
||||
for t in range(block_length):
|
||||
p = target_probs[b, t].item()
|
||||
if p > 0:
|
||||
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
|
||||
return logits
|
||||
|
||||
def test_set_timesteps(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
self.assertEqual(scheduler.num_inference_steps, 8)
|
||||
self.assertEqual(len(scheduler.timesteps), 8)
|
||||
self.assertEqual(scheduler.timesteps[0].item(), 7)
|
||||
self.assertEqual(scheduler.timesteps[-1].item(), 0)
|
||||
|
||||
def test_set_timesteps_invalid(self):
|
||||
scheduler = self.get_scheduler()
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler.set_timesteps(0)
|
||||
|
||||
def test_get_num_transfer_tokens_even(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
|
||||
self.assertEqual(schedule.sum().item(), 32)
|
||||
self.assertEqual(len(schedule), 8)
|
||||
self.assertTrue((schedule == 4).all().item())
|
||||
|
||||
def test_get_num_transfer_tokens_remainder(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
|
||||
self.assertEqual(schedule.sum().item(), 10)
|
||||
self.assertEqual(len(schedule), 3)
|
||||
self.assertEqual(schedule[0].item(), 4)
|
||||
self.assertEqual(schedule[1].item(), 3)
|
||||
self.assertEqual(schedule[2].item(), 3)
|
||||
|
||||
def test_transfer_schedule_created_on_set_timesteps(self):
|
||||
scheduler = self.get_scheduler(block_length=16)
|
||||
scheduler.set_timesteps(4)
|
||||
self.assertIsNotNone(scheduler._transfer_schedule)
|
||||
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
|
||||
|
||||
def test_save_load_config_round_trip(self):
|
||||
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
scheduler.save_config(tmpdir)
|
||||
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
|
||||
|
||||
self.assertEqual(loaded.config.block_length, 64)
|
||||
self.assertEqual(loaded.config.threshold, 0.8)
|
||||
self.assertEqual(loaded.config.editing_threshold, 0.5)
|
||||
self.assertEqual(loaded.config.minimal_topk, 2)
|
||||
|
||||
def test_from_config(self):
|
||||
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
|
||||
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
|
||||
self.assertEqual(new_scheduler.config.block_length, 16)
|
||||
self.assertEqual(new_scheduler.config.threshold, 0.7)
|
||||
|
||||
def test_step_commits_tokens(self):
|
||||
"""Verify that step() commits mask tokens based on confidence."""
|
||||
scheduler = self.get_scheduler(block_length=8)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, block_length, vocab_size = 1, 8, 32
|
||||
mask_id = 31
|
||||
|
||||
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
|
||||
# Create logits where confidence decreases with position
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
for i in range(block_length):
|
||||
logits[0, i, i] = 10.0 - i # decreasing confidence
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
threshold=0.95,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# With 8 tokens and 2 steps, first step should commit 4 tokens
|
||||
committed = out.transfer_index[0].sum().item()
|
||||
self.assertEqual(committed, 4)
|
||||
|
||||
def test_step_no_editing_by_default(self):
|
||||
"""Without editing_threshold, no non-mask tokens should be changed."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 10.0 # predict token 15 for all positions
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=None,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index.any().item())
|
||||
self.assertFalse(out.transfer_index[0, 0].item())
|
||||
self.assertFalse(out.transfer_index[0, 1].item())
|
||||
|
||||
def test_step_editing_replaces_tokens(self):
|
||||
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
# Token 0: predict 50 (different from 10) with very high logit
|
||||
logits[0, 0, 15] = 20.0
|
||||
# Token 1: predict 20 (same as current)
|
||||
logits[0, 1, 20] = 20.0
|
||||
# Mask tokens
|
||||
logits[0, 2, 5] = 5.0
|
||||
logits[0, 3, 6] = 5.0
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# Token 0 should be edited (different prediction, high confidence)
|
||||
self.assertTrue(out.editing_transfer_index[0, 0].item())
|
||||
# Token 1 should NOT be edited (same prediction)
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_prompt_mask_prevents_editing(self):
|
||||
"""Prompt positions should never be edited even with editing enabled."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 20.0
|
||||
prompt_mask = torch.tensor([True, True, False, False])
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
prompt_mask=prompt_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index[0, 0].item())
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_return_tuple(self):
|
||||
"""Verify tuple output when return_dict=False."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.full((1, 4), 31, dtype=torch.long)
|
||||
logits = torch.randn(1, 4, vocab_size)
|
||||
|
||||
result = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.assertEqual(len(result), 5)
|
||||
|
||||
def test_step_batched(self):
|
||||
"""Verify step works with batch_size > 1."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, vocab_size = 3, 32
|
||||
mask_id = 31
|
||||
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
|
||||
logits = torch.randn(batch_size, 4, vocab_size)
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
|
||||
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
|
||||
|
||||
def test_check_block_should_continue_finished(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([True, True])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=0,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_no_masks_no_edits(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=5,
|
||||
masks_remaining=False,
|
||||
editing_enabled=True,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=1,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_steps_exhausted(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=8,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_eos_finished_marks_batch(self):
|
||||
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, token, eos, mask, mask]
|
||||
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False, False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_ignores_when_masks_before_eos(self):
|
||||
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
|
||||
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertFalse(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_already_finished(self):
|
||||
"""Already-finished batches should stay finished."""
|
||||
mask_id, eos_id = 99, 2
|
||||
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False]])
|
||||
finished = torch.tensor([True])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=2,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_add_noise(self):
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
mask_token_id = 99
|
||||
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
prompt_length=2,
|
||||
block_length=4,
|
||||
mask_token_id=mask_token_id,
|
||||
generator=gen,
|
||||
)
|
||||
|
||||
# Prompt positions should never be masked
|
||||
self.assertFalse(masked[0, 0].item())
|
||||
self.assertFalse(masked[0, 1].item())
|
||||
self.assertFalse(masked_rev[0, 0].item())
|
||||
self.assertFalse(masked_rev[0, 1].item())
|
||||
|
||||
# Noisy should have mask_token_id where masked is True
|
||||
self.assertTrue((noisy[masked] == mask_token_id).all().item())
|
||||
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
|
||||
|
||||
# masked and masked_rev should be complementary within valid non-prompt positions
|
||||
non_prompt = torch.zeros_like(masked)
|
||||
non_prompt[0, 2:] = True
|
||||
combined = masked | masked_rev
|
||||
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
|
||||
|
||||
|
||||
class TestTopPFiltering(unittest.TestCase):
|
||||
def test_top_p_filtering(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
|
||||
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
|
||||
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
|
||||
|
||||
def test_top_p_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_p_filtering_one(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestTopKFiltering(unittest.TestCase):
|
||||
def test_top_k_filtering(self):
|
||||
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
|
||||
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
|
||||
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
|
||||
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
|
||||
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
|
||||
|
||||
def test_top_k_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_zero(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_large_k(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestSampleFromLogits(unittest.TestCase):
|
||||
def test_greedy_sampling(self):
|
||||
logits = torch.tensor([[1.0, 5.0, 2.0]])
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
self.assertEqual(tokens.shape, (1,))
|
||||
self.assertEqual(probs.shape, (1,))
|
||||
|
||||
def test_multinomial_sampling(self):
|
||||
logits = torch.tensor([[0.0, 100.0, -100.0]])
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=gen,
|
||||
use_multinomial=True,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
|
||||
def test_temperature_scaling(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
tokens, _ = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.01,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 2)
|
||||
|
||||
def test_negative_temperature_raises(self):
|
||||
logits = torch.tensor([[1.0, 2.0]])
|
||||
with self.assertRaises(ValueError):
|
||||
BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=-1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user