Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
76062a74e0 update 2026-03-23 17:16:44 +05:30
54 changed files with 359 additions and 5373 deletions

View File

@@ -446,10 +446,6 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
@@ -670,10 +666,6 @@
- local: api/pipelines/z_image
title: Z-Image
title: Image
- sections:
- local: api/pipelines/llada2
title: LLaDA2
title: Text
- sections:
- local: api/pipelines/allegro
title: Allegro
@@ -722,8 +714,6 @@
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/block_refinement
title: BlockRefinementScheduler
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/ddim_cogvideox

View File

@@ -1,32 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAE
The 2D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAE
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
```
## AutoencoderKLKVAE
[[autodoc]] AutoencoderKLKVAE
- decode
- all

View File

@@ -1,33 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAEVideo
The 3D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAEVideo
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
```
## AutoencoderKLKVAEVideo
[[autodoc]] AutoencoderKLKVAEVideo
- decode
- all

View File

@@ -1,90 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
steps.
## Usage
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
model_id = "inclusionAI/LLaDA2.1-mini"
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
output = pipe(
prompt="Write a short poem about the ocean.",
gen_length=256,
block_length=32,
num_inference_steps=32,
threshold=0.7,
editing_threshold=0.5,
max_post_steps=16,
temperature=0.0,
)
print(output.texts[0])
```
## Callbacks
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
window.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
block_x = callback_kwargs["block_x"]
# Inspect or modify `block_x` here.
return {"block_x": block_x}
out = pipe(
prompt="Write a short poem.",
callback_on_step_end=on_step_end,
callback_on_step_end_tensor_inputs=["block_x"],
)
```
## Recommended parameters
LLaDA2.1 models support two modes:
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|------|-------------|---------------------|------------------|
| Quality | 0.7 | 0.5 | 16 |
| Speed | 0.5 | `None` | 16 |
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
## LLaDA2Pipeline
[[autodoc]] LLaDA2Pipeline
- all
- __call__
## LLaDA2PipelineOutput
[[autodoc]] pipelines.LLaDA2PipelineOutput

View File

@@ -63,7 +63,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [LLaDA2](llada2) | text2text |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |

View File

@@ -1,25 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BlockRefinementScheduler
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
token with high confidence.
This scheduler is used by [`LLaDA2Pipeline`].
## BlockRefinementScheduler
[[autodoc]] BlockRefinementScheduler
## BlockRefinementSchedulerOutput
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput

View File

@@ -1,50 +0,0 @@
# Discrete Token Diffusion (Experimental)
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
## LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
### Train
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--text_column text \
--output_dir llada2-output \
--max_train_steps 1000 \
--prompt_length 32 \
--block_length 32 \
--lambda_conf 2.0 \
--conf_temperature 0.5
```
If you don't want to download a dataset, you can use random-token data:
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--output_dir llada2-output \
--use_dummy_data \
--num_dummy_samples 2048
```
### Sample
```bash
python examples/discrete_diffusion/sample_llada2.py \
--model_id inclusionAI/LLaDA2.1-mini \
--prompt "Write a short poem about the ocean." \
--gen_length 256 \
--num_inference_steps 32 \
--threshold 0.7 \
--editing_threshold 0.5 \
--max_post_steps 16 \
--use_chat_template \
--add_generation_prompt
```

View File

@@ -1,263 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample script for LLaDA2-style discrete diffusion text generation.
This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.
Example usage:
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading
def main():
parser = argparse.ArgumentParser(
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
)
parser.add_argument(
"--model_id",
type=str,
default="inclusionAI/LLaDA2.0-mini",
help="HuggingFace model ID or path to local model.",
)
parser.add_argument(
"--prompt",
type=str,
default="Why does Camus think that Sisyphus is happy?",
help="Text prompt to generate from.",
)
parser.add_argument(
"--gen_length",
type=int,
default=2048,
help="Number of tokens to generate.",
)
parser.add_argument(
"--block_length",
type=int,
default=32,
help="Size of each generation block.",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=32,
help="Number of refinement steps per block.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature (0.0 for greedy).",
)
parser.add_argument(
"--top_p",
type=float,
default=None,
help="Nucleus sampling probability threshold.",
)
parser.add_argument(
"--top_k",
type=int,
default=None,
help="Top-k sampling parameter.",
)
parser.add_argument(
"--threshold",
type=float,
default=0.95,
help="Confidence threshold for committing tokens.",
)
parser.add_argument(
"--editing_threshold",
type=float,
default=None,
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
)
parser.add_argument(
"--max_post_steps",
type=int,
default=0,
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
)
parser.add_argument(
"--sampling_method",
type=str,
default="multinomial",
choices=["auto", "greedy", "multinomial"],
help="Sampling method for block refinement.",
)
parser.add_argument(
"--eos_early_stop",
action="store_true",
help="Stop generation early when EOS token is generated.",
)
parser.add_argument(
"--use_chat_template",
action="store_true",
help="Use the tokenizer chat template for the prompt.",
)
parser.add_argument(
"--add_generation_prompt",
action="store_true",
help="Add the generation prompt when using the chat template.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on.",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model dtype.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility.",
)
parser.add_argument(
"--offload",
type=str,
default=None,
choices=["group", "sequential"],
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
)
args = parser.parse_args()
# Parse dtype
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
torch_dtype = dtype_map[args.dtype]
print(f"Loading model: {args.model_id}")
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
# Load model with appropriate memory settings based on offload strategy
if args.offload == "group":
# For group offloading, load to CPU first then apply hooks
print("Using group offloading for memory efficiency...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Apply group offloading with CUDA streams for better performance
onload_device = torch.device(args.device)
offload_device = torch.device("cpu")
apply_group_offloading(
model,
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True,
)
elif args.offload == "sequential":
# For sequential offloading, load to CPU first
print("Using sequential CPU offloading (slower but lower memory)...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Sequential offloading will be applied via pipeline
else:
# Default: use device_map="auto" for automatic memory management
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
device_map="auto",
low_cpu_mem_usage=True,
revision=args.revision,
)
model.eval()
# Create pipeline
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
# Apply sequential CPU offload if requested
if args.offload == "sequential":
pipe.enable_sequential_cpu_offload()
# Set up generator for reproducibility
generator = None
if args.seed is not None:
generator = torch.Generator(device=args.device).manual_seed(args.seed)
print(f"\nPrompt: {args.prompt}")
print(
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
)
print("-" * 50)
# Generate
output = pipe(
prompt=args.prompt,
use_chat_template=args.use_chat_template,
add_generation_prompt=args.add_generation_prompt,
gen_length=args.gen_length,
block_length=args.block_length,
num_inference_steps=args.num_inference_steps,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
threshold=args.threshold,
editing_threshold=args.editing_threshold,
max_post_steps=args.max_post_steps,
sampling_method=args.sampling_method,
eos_early_stop=args.eos_early_stop,
generator=generator,
)
print("\nGenerated text:")
print(output.texts[0])
print(f"\nGenerated {output.sequences.shape[1]} tokens")
if __name__ == "__main__":
main()

View File

@@ -1,321 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
from diffusers import BlockRefinementScheduler
from diffusers.training_utils import compute_confidence_aware_loss
logger = get_logger(__name__)
@dataclass
class TrainConfig:
model_name_or_path: str
dataset_name: str
dataset_config_name: Optional[str]
text_column: str
cache_dir: Optional[str]
use_dummy_data: bool
num_dummy_samples: int
output_dir: str
seed: int
max_train_steps: int
checkpointing_steps: int
logging_steps: int
per_device_train_batch_size: int
gradient_accumulation_steps: int
learning_rate: float
weight_decay: float
lr_scheduler: str
lr_warmup_steps: int
max_length: int
prompt_length: int
block_length: int
lambda_conf: float
conf_temperature: float
def parse_args() -> TrainConfig:
parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
parser.add_argument("--dataset_name", type=str, default="wikitext")
parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
parser.add_argument("--text_column", type=str, default="text")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
parser.add_argument("--num_dummy_samples", type=int, default=2048)
parser.add_argument("--output_dir", type=str, default="block-refinement-output")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_train_steps", type=int, default=1000)
parser.add_argument("--checkpointing_steps", type=int, default=500)
parser.add_argument("--logging_steps", type=int, default=50)
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument(
"--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
)
parser.add_argument("--lr_warmup_steps", type=int, default=100)
parser.add_argument("--max_length", type=int, default=256)
parser.add_argument("--prompt_length", type=int, default=32)
parser.add_argument("--block_length", type=int, default=32)
parser.add_argument("--lambda_conf", type=float, default=2.0)
parser.add_argument("--conf_temperature", type=float, default=0.5)
args = parser.parse_args()
return TrainConfig(**vars(args))
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
texts = examples[text_column]
texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
class RandomTokenDataset(torch.utils.data.Dataset):
def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
self.num_samples = int(num_samples)
self.seq_len = int(seq_len)
self.vocab_size = int(vocab_size)
self.pad_token_id = int(pad_token_id)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
del idx
input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask}
def main():
cfg = parse_args()
if cfg.prompt_length >= cfg.max_length:
raise ValueError("`prompt_length` must be < `max_length`.")
if cfg.block_length <= 0:
raise ValueError("`block_length` must be > 0.")
project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
project_config=project_config,
)
if accelerator.is_main_process:
os.makedirs(cfg.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
set_seed(cfg.seed)
logger.info("Training configuration: %s", asdict(cfg))
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.mask_token_id is None:
tokenizer.add_special_tokens({"mask_token": "[MASK]"})
load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
model.resize_token_embeddings(len(tokenizer))
if load_dtype == torch.float32:
model.to(dtype=torch.float32)
mask_token_id = int(tokenizer.mask_token_id)
if cfg.use_dummy_data:
dataset = RandomTokenDataset(
num_samples=cfg.num_dummy_samples,
seq_len=cfg.max_length,
vocab_size=len(tokenizer),
pad_token_id=int(tokenizer.pad_token_id),
)
train_dataloader = DataLoader(
dataset,
shuffle=True,
batch_size=cfg.per_device_train_batch_size,
drop_last=True,
)
else:
raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
if "train" not in raw_datasets:
raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
with accelerator.main_process_first():
tokenized = raw_datasets["train"].map(
lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
batched=True,
remove_columns=raw_datasets["train"].column_names,
desc="Tokenizing",
)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
train_dataloader = DataLoader(
tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=cfg.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.max_train_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)
global_step = 0
model.train()
for _epoch in range(num_train_epochs):
for batch in train_dataloader:
with accelerator.accumulate(model):
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))
gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=cfg.prompt_length,
block_length=cfg.block_length,
mask_token_id=mask_token_id,
generator=gen,
)
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)
logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
logits_rev = model(
input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
).logits
logits = logits.clone()
logits[..., mask_token_id] = torch.finfo(logits.dtype).min
logits_rev = logits_rev.clone()
logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min
valid = attention_mask.to(dtype=torch.bool)
masked = masked & valid
masked_rev = masked_rev & valid
labels = input_ids.clone()
labels[~masked] = -100
labels_rev = input_ids.clone()
labels_rev[~masked_rev] = -100
weights = masked.to(dtype=logits.dtype)
weights_rev = masked_rev.to(dtype=logits.dtype)
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits,
labels,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights,
)
loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
logits_rev,
labels_rev,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights_rev,
)
total_loss = loss + loss_rev
accelerator.backward(total_loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if accelerator.sync_gradients:
global_step += 1
if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
logger.info(
"step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
global_step,
total_loss.item(),
(loss_sft + loss_sft_rev).item(),
(loss_conf + loss_conf_rev).item(),
lr_scheduler.get_last_lr()[0],
)
print(
f"step={global_step} loss={total_loss.item():.4f} "
f"sft={(loss_sft + loss_sft_rev).item():.4f} "
f"conf={(loss_conf + loss_conf_rev).item():.4f} "
f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
)
if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
os.makedirs(save_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
tokenizer.save_pretrained(save_dir)
if global_step >= cfg.max_train_steps:
break
if global_step >= cfg.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
final_dir = os.path.join(cfg.output_dir, "final")
os.makedirs(final_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
tokenizer.save_pretrained(final_dir)
logger.info("Done.")
if __name__ == "__main__":
main()

View File

@@ -193,8 +193,6 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLKVAE",
"AutoencoderKLKVAEVideo",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
@@ -344,8 +342,6 @@ else:
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
"BlockRefinementScheduler",
"BlockRefinementSchedulerOutput",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
@@ -582,8 +578,6 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LLaDA2Pipeline",
"LLaDA2PipelineOutput",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
@@ -981,8 +975,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
@@ -1128,8 +1120,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
BlockRefinementScheduler,
BlockRefinementSchedulerOutput,
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
@@ -1345,8 +1335,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LLaDA2Pipeline,
LLaDA2PipelineOutput,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -31,7 +32,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -326,7 +327,7 @@ class PartitionAnythingSharder:
return tensor
@lru_cache_unless_export(maxsize=64)
@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -2443,191 +2443,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
scale = alpha / rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check if upweight is sparse
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
# Detect number of blocks from keys
num_double_layers = 0
num_single_layers = 0
for key in state_dict.keys():
if key.startswith("lora_unet_double_blocks_"):
block_idx = int(key.split("_")[4])
num_double_layers = max(num_double_layers, block_idx + 1)
elif key.startswith("lora_unet_single_blocks_"):
block_idx = int(key.split("_")[4])
num_single_layers = max(num_single_layers, block_idx + 1)
ait_sd = {}
for i in range(num_double_layers):
# Attention projections
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
# MLP layers (Flux2 uses ff.linear_in/linear_out)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.linear_out",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.linear_out",
)
for i in range(num_single_layers):
# Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
)
# Single blocks: linear2 -> attn.to_out
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.attn.to_out",
)
# Handle optional extra keys
extra_mappings = {
"lora_unet_img_in": "transformer.x_embedder",
"lora_unet_txt_in": "transformer.context_embedder",
"lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
"lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
"lora_unet_final_layer_linear": "transformer.proj_out",
}
for sds_key, ait_key in extra_mappings.items():
_convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
remaining_keys = list(state_dict.keys())
if remaining_keys:
logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
return ait_sd
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.

View File

@@ -43,7 +43,6 @@ from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux2_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5674,13 +5673,6 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
if is_peft_format:
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}

View File

@@ -15,7 +15,6 @@
import inspect
import json
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Literal
@@ -45,13 +44,33 @@ from .unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
},
)
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
class PeftAdapterMixin:

View File

@@ -40,8 +40,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
@@ -163,8 +161,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,

View File

@@ -9,8 +9,6 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_kvae import AutoencoderKLKVAE
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio

View File

@@ -87,14 +87,7 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):

View File

@@ -87,14 +87,7 @@ class HunyuanVideo15RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanVideo15AttnBlock(nn.Module):

View File

@@ -1,802 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class KVAEResnetBlock2D(nn.Module):
r"""
A Resnet block with optional guidance.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
conv_shortcut (`bool`, *optional*, default to `False`):
If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection.
temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding.
zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization.
add_conv (`bool`, *optional*, default to `False`):
If `True` add conv2d layer for normalization.
normalization (`nn.Module`, *optional*, default to `None`): The normalization layer.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
temb_channels: int = 512,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
else:
self.nin_shortcut = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
h = self.norm1(h)
else:
h = self.norm1(h, zq)
h = self.nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h)
else:
h = self.norm2(h, zq)
h = self.nonlinearity(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class KVAEPXSDownsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
A Downsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, default to `2`): The downsampling factor.
"""
super().__init__()
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (bchw)
pxs_interm = self.unshuffle(x)
b, c, h, w = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
pxs_out = torch.mean(pxs_interm_view, dim=2)
conv_out = self.spatial_conv(x)
# adding it all together
out = conv_out + pxs_out
return self.linear(out)
class KVAEPXSUpsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
An Upsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, default to `2`): The upsampling factor.
"""
super().__init__()
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(self.factor**2, dim=1)
pxs_interm = self.shuffle(repeated)
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
conv_out = self.spatial_conv(image_like_ups)
# adding it all together
out = conv_out + pxs_interm
return self.linear(out)
class KVAEDecoderSpatialNorm2D(nn.Module):
r"""
A 2D normalization module for decoder.
Args:
in_channels (`int`): The number of channels in the input.
zq_channels (`int`): The number of channels in the guidance.
add_conv (`bool`, *optional*, default to `false`):
If `True` add conv2d 3x3 layer for guidance in the beginning.
"""
def __init__(
self,
in_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
self.add_conv = add_conv
if add_conv:
self.conv = nn.Conv2d(
in_channels=zq_channels,
out_channels=zq_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
self.conv_y = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
self.conv_b = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_first = f
f_first_size = f_first.shape[2:]
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class KVAEEncoder2D(nn.Module):
r"""
A 2D encoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of Resnet blocks.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of output channels.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
double_z: bool = True,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
if isinstance(num_res_blocks, int):
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
else:
self.num_res_blocks = num_res_blocks
self.nonlinearity = get_activation(act_fn)
self.in_channels = in_channels
self.conv_in = nn.Conv2d(
in_channels=in_channels,
out_channels=self.ch,
kernel_size=3,
padding=(1, 1),
)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks[i_level]):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level < self.num_resolutions - 1:
down.downsample = KVAEPXSDownsample(in_channels=block_in) # mb: bad out channels
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
# end
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(
in_channels=block_in,
out_channels=2 * z_channels if double_z else z_channels,
kernel_size=3,
padding=(1, 1),
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks[i_level]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
else:
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
else:
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class KVAEDecoder2D(nn.Module):
r"""
A 2D decoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
out_ch (`int`): The number of output channels.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of Resnet blocks.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of input channels.
give_pre_end (`bool`, *optional*, default to `false`):
If `True` exit the forward pass early and return the penultimate feature map.
zq_ch (`bool`, *optional*, default to `None`): The number of channels in the guidance.
add_conv (`bool`, *optional*, default to `false`): If `True` add conv2d layer for Resnet normalization layer.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
out_ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
give_pre_end: bool = False,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = nn.Conv2d(
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = KVAEPXSUpsample(in_channels=block_in)
self.up.insert(0, up)
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv) # , gather=gather_norm)
self.conv_out = nn.Conv2d(
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor) -> torch.Tensor:
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
zq = z
h = self.conv_in(z)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
else:
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
else:
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
num_enc_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in encoder multiresolution layers.
num_dec_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in decoder multiresolution layers.
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels of encoder.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: int = 128,
num_enc_blocks: int = 2,
num_dec_blocks: int = 2,
z_channels: int = 16,
double_z: bool = True,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
sample_size: int = 1024,
):
super().__init__()
# pass init params to Encoder
self.encoder = KVAEEncoder2D(
in_channels=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_enc_blocks,
z_channels=z_channels,
double_z=double_z,
)
# pass init params to Decoder
self.decoder = KVAEDecoder2D(
out_ch=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_dec_blocks,
in_channels=None,
z_channels=z_channels,
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -1,954 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
return F.silu(x)
# =============================================================================
# Base layers
# =============================================================================
class KVAESafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
"""
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
memory_count = input.numel() * input.element_size() / (10**9)
if memory_count > 3:
kernel_size = self.kernel_size[0]
part_num = math.ceil(memory_count / 2)
input_chunks = torch.chunk(input, part_num, dim=2)
if write_to is None:
output = []
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
output.append(super().forward(z))
return torch.cat(output, dim=2)
else:
time_offset = 0
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
z_time = z.size(2) - (kernel_size - 1)
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
time_offset += z_time
return write_to
else:
if write_to is None:
return super().forward(input)
else:
write_to[...] = super().forward(input)
return write_to
class KVAECausalConv3d(nn.Module):
r"""
A 3D causal convolution layer.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor) -> torch.Tensor:
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
input_padded = F.pad(input, padding_3d, mode="replicate")
return self.conv(input_padded)
class KVAECachedCausalConv3d(nn.Module):
r"""
A 3D causal convolution layer with caching for temporal processing.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
t_stride = self.stride[0]
padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
input_parallel = F.pad(input, padding_3d, mode="replicate")
if cache["padding"] is None:
first_frame = input_parallel[:, :, :1]
time_pad_shape = list(first_frame.shape)
time_pad_shape[2] = self.time_pad
padding = first_frame.expand(time_pad_shape)
else:
padding = cache["padding"]
out_size = list(input.shape)
out_size[1] = self.conv.out_channels
if t_stride == 2:
out_size[2] = (input.size(2) + 1) // 2
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
offset_out = math.ceil(padding.size(2) / t_stride)
offset_in = offset_out * t_stride - padding.size(2)
if offset_out > 0:
padding_poisoned = torch.cat(
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
)
output[:, :, :offset_out] = self.conv(padding_poisoned)
if offset_out < output.size(2):
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
pad_offset = (
offset_in
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
+ t_stride
)
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
return output
class KVAECachedGroupNorm(nn.Module):
r"""
GroupNorm with caching support for temporal processing.
"""
def __init__(self, in_channels: int):
super().__init__()
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
out = self.norm_layer(x)
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
cache["mean"] = 1
cache["var"] = 1
return out
# =============================================================================
# Cached layers
# =============================================================================
class KVAECachedSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization for decoder with caching.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = KVAECachedGroupNorm(f_channels)
self.add_conv = add_conv
if add_conv:
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
if zq.size(2) > 1:
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = zq_first
else:
f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, cache["add_conv"])
norm_f = self.norm_layer(f, cache["norm"])
norm_f = norm_f * self.conv_y(zq)
norm_f = norm_f + self.conv_b(zq)
return norm_f
class KVAECachedResnetBlock3D(nn.Module):
r"""
A 3D ResNet block with caching.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 0,
zq_ch: Optional[int] = None,
add_conv: bool = False,
gather_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if zq_ch is None:
self.norm1 = KVAECachedGroupNorm(in_channels)
else:
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = KVAECachedGroupNorm(out_channels)
else:
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
else:
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
# Encoder path - norm takes cache
h = self.norm1(h, cache=layer_cache["norm1"])
else:
# Decoder path - spatial norm takes zq and cache
h = self.norm1(h, zq, cache=layer_cache["norm1"])
h = F.silu(h)
h = self.conv1(h, cache=layer_cache["conv1"])
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h, cache=layer_cache["norm2"])
else:
h = self.norm2(h, zq, cache=layer_cache["norm2"])
h = F.silu(h)
h = self.conv2(h, cache=layer_cache["conv2"])
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
else:
x = self.nin_shortcut(x)
return x + h
class KVAECachedPXSDownsample(nn.Module):
r"""
A 3D downsampling layer using PixelUnshuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
)
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
pxs_interm = self.unshuffle(pxs_input)
b_it, c_it, h_it, w_it = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
pxs_out = torch.mean(pxs_interm_view, dim=2)
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
conv_out = self.spatial_conv(input)
return conv_out + pxs_out
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
b, c, t, h, w = input.shape
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if cache[0]["padding"] is None:
first, rest = permuted[..., :1], permuted[..., 1:]
if rest.size(-1) > 0:
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
full_interp = torch.cat([first, rest_interp], dim=-1)
else:
full_interp = first
else:
rest = permuted
if rest.size(-1) > 0:
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
t_new = full_interp.size(-1)
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
conv_out = self.temporal_conv(input, cache[0])
return conv_out + full_interp
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
out = self.spatial_downsample(x)
if self.temporal_compress:
out = self.temporal_downsample(out, cache=cache)
return self.linear(out)
class KVAECachedPXSUpsample(nn.Module):
r"""
A 3D upsampling layer using PixelShuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 1, 1),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
)
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
out = self.spatial_conv(input_interp)
return input_interp + out
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
time_factor = 1.0 + 1.0 * (input.size(2) > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
repeated = input.repeat_interleave(int(time_factor), dim=2)
if cache["padding"] is None:
tail = repeated[..., int(time_factor - 1) :, :, :]
else:
tail = repeated
conv_out = self.temporal_conv(tail, cache)
return conv_out + tail
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
if self.temporal_compress:
x = self.temporal_upsample(x, cache)
s_out = self.spatial_upsample(x)
to = torch.empty_like(s_out)
lin_out = self.linear(s_out, write_to=to)
return lin_out
# =============================================================================
# Cached Encoder/Decoder
# =============================================================================
class KVAECachedEncoder3D(nn.Module):
r"""
Cached 3D Encoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
in_channels: int = 3,
z_channels: int = 16,
double_z: bool = True,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
block_in = ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
else:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
self.down.append(down)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.norm_out = KVAECachedGroupNorm(block_in)
self.conv_out = KVAECachedCausalConv3d(
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
h = self.conv_in(x, cache=cache_dict["conv_in"])
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
)
else:
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
h = self.norm_out(h, cache=cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache=cache_dict["conv_out"])
return h
class KVAECachedDecoder3D(nn.Module):
r"""
Cached 3D Decoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
out_ch: int = 3,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 16,
zq_ch: Optional[int] = None,
add_conv: bool = False,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
else:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
self.up.insert(0, up)
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
zq = z
h = self.conv_in(z, cache_dict["conv_in"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
)
else:
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
h = self.norm_out(h, zq, cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache_dict["conv_out"])
return h
# =============================================================================
# Main AutoencoderKL class
# =============================================================================
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[KVAE](https://github.com/kandinskylab/kvae-1).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
ch (`int`, *optional*, defaults to 128): Base channel count.
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["KVAECachedResnetBlock3D"]
@register_to_config
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
in_channels: int = 3,
out_ch: int = 3,
z_channels: int = 16,
temporal_compress_times: int = 4,
):
super().__init__()
self.encoder = KVAECachedEncoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
in_channels=in_channels,
z_channels=z_channels,
double_z=True,
temporal_compress_times=temporal_compress_times,
)
self.decoder = KVAECachedDecoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
out_ch=out_ch,
z_channels=z_channels,
temporal_compress_times=temporal_compress_times,
)
self.use_slicing = False
self.use_tiling = False
def _make_encoder_cache(self) -> Dict:
"""Create empty cache for cached encoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_enc"),
"mid_2": make_dict("resblock_enc"),
"norm_out": make_dict("norm_enc"),
"conv_out": make_dict("conv"),
}
# Encoder uses num_res_blocks per level
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
return cache
def _make_decoder_cache(self) -> Dict:
"""Create empty cache for decoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_dec"),
"mid_2": make_dict("resblock_dec"),
"norm_out": make_dict("norm_dec"),
"conv_out": make_dict("conv"),
}
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
return cache
def enable_slicing(self) -> None:
r"""Enable sliced VAE decoding."""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""Disable sliced VAE decoding."""
self.use_slicing = False
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
# Cached encoder processes by segments
cache = self._make_encoder_cache()
split_list = [seg_len + 1]
n_frames = x.size(2) - (seg_len + 1)
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
latent = []
for chunk in torch.split(x, split_list, dim=2):
l = self.encoder(chunk, cache)
sample, _ = torch.chunk(l, 2, dim=1)
latent.append(sample)
return torch.cat(latent, dim=2)
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
# For cached encoder, we already did the split in _encode
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
posterior = DiagonalGaussianDistribution(h_double)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
cache = self._make_decoder_cache()
temporal_compress = self.config.temporal_compress_times
split_list = [seg_len + 1]
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
split_list = [math.ceil(size / temporal_compress) for size in split_list]
recs = []
for chunk in torch.split(z, split_list, dim=2):
out = self.decoder(chunk, cache)
recs.append(out)
return torch.cat(recs, dim=2)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
Args:
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -105,14 +105,7 @@ class QwenImageRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class QwenImageUpsample(nn.Upsample):

View File

@@ -196,14 +196,7 @@ class WanRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -24,7 +25,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -306,7 +307,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -427,7 +428,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -449,7 +450,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -933,7 +934,6 @@ class QwenImageTransformer2DModel(
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):

View File

@@ -788,12 +788,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
if all(seq == max_seqlen for seq in item_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
@@ -874,12 +871,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
if all(seq == max_seqlen for seq in unified_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None

View File

@@ -285,7 +285,6 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
@@ -729,7 +728,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import (
LTXConditionPipeline,

View File

@@ -16,29 +16,22 @@ from typing import Callable
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import (
is_cosmos_guardrail_available,
is_torch_xla_available,
is_torchvision_available,
logging,
replace_example_docstring,
)
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_torchvision_available():
import torchvision.transforms.functional
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:

View File

@@ -1,47 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,491 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import torch
from tqdm.auto import tqdm
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...schedulers import BlockRefinementScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
>>> model_id = "inclusionAI/LLaDA2.1-mini"
>>> model = AutoModelForCausalLM.from_pretrained(
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
>>> scheduler = BlockRefinementScheduler()
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
>>> print(output.texts[0])
```
"""
@dataclass
class LLaDA2PipelineOutput(BaseOutput):
sequences: torch.LongTensor
texts: list[str] | None = None
class LLaDA2Pipeline(DiffusionPipeline):
r"""
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
vocab_size]`.
"""
model: Any
scheduler: BlockRefinementScheduler
tokenizer: Any
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
def __init__(
self,
model: Any,
scheduler: BlockRefinementScheduler,
tokenizer: Any | None = None,
):
super().__init__()
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
@property
def num_timesteps(self):
return self._num_timesteps
# --- Prompt encoding ---
def _prepare_input_ids(
self,
*,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
use_chat_template: bool,
add_generation_prompt: bool,
chat_template_kwargs: dict[str, Any] | None,
) -> torch.LongTensor:
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
if input_ids is not None:
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 2:
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
return input_ids
if self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and prompt is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if messages is None and prompt is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
chat_template_kwargs = chat_template_kwargs or {}
if messages is not None:
encoded = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
if isinstance(prompt, list):
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
encoded = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]
def check_inputs(
self,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
gen_length: int,
block_length: int,
num_inference_steps: int,
minimal_topk: int,
threshold: float,
sampling_method: str,
output_type: str,
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
callback_on_step_end_tensor_inputs: list[str] | None,
):
# Input source validation
if prompt is None and messages is None and input_ids is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
if prompt is not None and messages is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if input_ids is not None:
if input_ids.ndim not in (1, 2):
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
if prompt is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
# Generation parameter validation
if gen_length <= 0:
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
if block_length <= 0:
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
if minimal_topk <= 0:
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
if sampling_method not in {"auto", "greedy", "multinomial"}:
raise ValueError(
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
)
if output_type not in {"seq", "text"}:
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
# Callback validation
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str | list[str] | None = None,
messages: list[dict[str, str]] | None = None,
input_ids: torch.LongTensor | None = None,
use_chat_template: bool = True,
add_generation_prompt: bool = True,
gen_length: int = 2048,
block_length: int = 32,
num_inference_steps: int = 32,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "multinomial",
threshold: float = 0.7,
editing_threshold: float | None = 0.5,
max_post_steps: int = 16,
minimal_topk: int = 1,
eos_early_stop: bool = True,
eos_token_id: int | None = None,
mask_token_id: int | None = None,
generator: torch.Generator | None = None,
output_type: str = "text",
return_dict: bool = True,
callback_on_step_end: Callable[[int, int, dict], None]
| PipelineCallback
| MultiPipelineCallbacks
| None = None,
callback_on_step_end_tensor_inputs: list[str] | None = None,
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
"""
Generate text with block-wise refinement.
Args:
prompt (`str` or `List[str]`, *optional*):
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
available, the prompt is wrapped in a chat message before tokenization.
messages (`List[Dict[str, str]]`, *optional*):
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
when provided. Requires a tokenizer with `apply_chat_template`.
input_ids (`torch.LongTensor`, *optional*):
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
use_chat_template (`bool`, defaults to `True`):
Whether to wrap the prompt in a chat template.
add_generation_prompt (`bool`, defaults to `True`):
Whether to add the generation prompt when using chat templates.
gen_length (`int`):
Number of tokens to generate.
block_length (`int`):
Block size for refinement.
num_inference_steps (`int`):
Number of refinement steps per block.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`):
Confidence threshold for committing tokens.
editing_threshold (`float`, *optional*):
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
negative value to disable editing. Defaults to `0.5`.
max_post_steps (`int`):
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
used when `editing_threshold` is enabled. Defaults to `16`.
minimal_topk (`int`):
Minimum number of tokens to commit per step.
eos_early_stop (`bool`):
Whether to stop after committing EOS in a block.
eos_token_id (`int`, *optional*):
EOS token ID to use for early stopping.
mask_token_id (`int`, *optional*):
Mask token ID to use for the template.
generator (`torch.Generator`, *optional*):
RNG for sampling.
output_type (`str`, defaults to `"text"`):
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
token ID sequences only.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
timestep: int, callback_kwargs: Dict)`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
`confidence`, `active_block`.
Examples:
"""
# 1. Check inputs early
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["block_x"]
self.check_inputs(
prompt=prompt,
messages=messages,
input_ids=input_ids,
gen_length=gen_length,
block_length=block_length,
num_inference_steps=num_inference_steps,
minimal_topk=minimal_topk,
threshold=threshold,
sampling_method=sampling_method,
output_type=output_type,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Prepare input IDs from prompt/messages/input_ids
prompt_ids = self._prepare_input_ids(
prompt=prompt,
messages=messages,
input_ids=input_ids,
use_chat_template=use_chat_template,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=None,
)
device = self._execution_device
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(device=device)
batch_size, prompt_length = prompt_ids.shape
if eos_token_id is None:
eos_token_id = self.eos_token_id
if mask_token_id is None:
mask_token_id = self.mask_token_id
if mask_token_id is None:
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 3. Build attention mask and position IDs
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
# 4. Prepare latents (fully masked sequence)
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
if prompt_length > 0:
x[:, :prompt_length] = prompt_ids
prefill_blocks = prompt_length // block_length
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
global_step = 0
# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
block_attn_mask = attn_mask[:, :current_window_end]
block_position_ids = position_ids[:, :current_window_end]
# Identify which positions in the block are prompt (non-editable).
block_start_pos = num_block * block_length
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
if block_start_pos < prompt_length:
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
prompt_mask_in_block[:prompt_end_in_block] = True
post_steps = 0
step_idx = 0
should_continue = True
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = self.progress_bar(total=num_inference_steps)
while should_continue:
block_tokens = block_x[:, -block_length:]
masks_remaining = (block_tokens == mask_token_id).any()
if not masks_remaining:
post_steps += 1
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
block_logits = logits[:, -block_length:, :]
scheduler_output = self.scheduler.step(
model_output=block_logits,
timestep=step_idx,
sample=block_tokens,
mask_token_id=mask_token_id,
temperature=temperature,
top_p=top_p,
top_k=top_k,
sampling_method=sampling_method,
threshold=threshold,
editing_threshold=editing_threshold,
minimal_topk=minimal_topk,
prompt_mask=prompt_mask_in_block,
generator=generator,
return_dict=True,
)
transfer_index = scheduler_output.transfer_index
editing_transfer_index = scheduler_output.editing_transfer_index
final_transfer = transfer_index | editing_transfer_index
if final_transfer.any():
block_x[:, -block_length:] = scheduler_output.prev_sample
if eos_early_stop and eos_token_id is not None:
finished = self.scheduler.check_eos_finished(
cur_x=block_x,
sampled_tokens=scheduler_output.sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_token_id,
mask_token_id=mask_token_id,
prompt_length=prompt_length,
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
block_x = callback_outputs.pop("block_x", block_x)
global_step += 1
if masks_remaining:
step_idx += 1
progress_bar.update(1)
should_continue = self.scheduler.check_block_should_continue(
step_idx=step_idx,
masks_remaining=masks_remaining,
editing_enabled=editing_enabled,
editing_transfer_index=editing_transfer_index,
post_steps=post_steps,
max_post_steps=max_post_steps,
finished=finished,
)
progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
break
# 6. Post-process output
generated = x[:, : prompt_length + gen_length]
sequences = generated[:, prompt_length:]
if eos_token_id is not None and batch_size == 1:
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
texts = None
if output_type == "text" and self.tokenizer is not None:
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
if not return_dict:
return sequences.to(device=device), texts
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]

View File

@@ -40,7 +40,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -146,7 +145,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
from .scheduling_amused import AmusedScheduler
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler

View File

@@ -1,460 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class BlockRefinementSchedulerOutput(BaseOutput):
"""
Output class for block refinement scheduling.
Args:
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Updated block tokens after the current refinement step.
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were committed (mask-filling).
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were edited (non-mask replacement).
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Sampled token IDs from the model logits.
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
Probabilities of the sampled tokens.
"""
prev_sample: torch.LongTensor
transfer_index: torch.BoolTensor
editing_transfer_index: torch.BoolTensor
sampled_tokens: torch.LongTensor
sampled_probs: torch.Tensor
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler for block-wise iterative refinement (commit-by-confidence).
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
the number of refinement steps.
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
"""
order = 1
@register_to_config
def __init__(
self,
block_length: int = 32,
num_inference_steps: int = 32,
threshold: float = 0.95,
editing_threshold: float | None = None,
minimal_topk: int = 1,
):
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
self._transfer_schedule: torch.LongTensor | None = None
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
device=device if device is not None else "cpu"
)
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
if num_inference_steps <= 0:
return torch.zeros((0,), dtype=torch.long)
base = block_length // num_inference_steps
remainder = block_length % num_inference_steps
out = torch.full((num_inference_steps,), base, dtype=torch.long)
out[:remainder] += 1
return out
# --- SAR sampling utilities ---
@staticmethod
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
"""Nucleus (top-p) logit filtering."""
if top_p is None or top_p >= 1.0:
return logits
if not (0.0 < top_p <= 1.0):
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > float(top_p)
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
return filtered
@staticmethod
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
"""Top-k logit filtering."""
if top_k is None or top_k <= 0:
return logits
if top_k >= logits.shape[-1]:
return logits
values, _ = torch.topk(logits, k=top_k, dim=-1)
min_keep = values[..., -1, None]
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
@staticmethod
def _sample_from_logits(
logits: torch.Tensor,
*,
temperature: float,
top_k: int | None,
top_p: float | None,
generator: torch.Generator | None,
use_multinomial: bool,
) -> tuple[torch.LongTensor, torch.Tensor]:
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
if temperature < 0:
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
vocab_size = logits.shape[-1]
flat_logits = logits.reshape(-1, vocab_size)
if temperature == 0.0 or not use_multinomial:
probs = torch.softmax(flat_logits.float(), dim=-1)
token = flat_logits.argmax(dim=-1, keepdim=True)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
scaled = flat_logits
if temperature != 1.0:
scaled = flat_logits / temperature
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
probs = torch.softmax(filtered.float(), dim=-1)
token = torch.multinomial(probs, num_samples=1, generator=generator)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.LongTensor,
*,
mask_token_id: int,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "auto",
threshold: float | None = None,
editing_threshold: float | None = None,
minimal_topk: int | None = None,
prompt_mask: torch.BoolTensor | None = None,
generator: torch.Generator | None = None,
return_dict: bool = True,
) -> (
BlockRefinementSchedulerOutput
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
):
"""
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
ones.
Args:
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
Raw logits from the model for the current block.
timestep (`int` or `torch.Tensor`):
Current step index within the block's refinement schedule.
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Current block token IDs (contains mask tokens for uncommitted positions).
mask_token_id (`int`):
Token ID used for masked positions.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`, *optional*):
Confidence threshold for committing tokens. Defaults to config value.
editing_threshold (`float`, *optional*):
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
config value.
minimal_topk (`int`, *optional*):
Minimum tokens to commit per step. Defaults to config value.
prompt_mask (`torch.BoolTensor`, *optional*):
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
generator (`torch.Generator`, *optional*):
RNG for sampling.
return_dict (`bool`):
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
"""
if threshold is None:
threshold = float(self.config.threshold)
if editing_threshold is None:
editing_threshold = self.config.editing_threshold
if minimal_topk is None:
minimal_topk = self.config.minimal_topk
# Sample from logits
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
sampled_tokens, sampled_probs = self._sample_from_logits(
model_output,
temperature=temperature,
top_k=top_k,
top_p=top_p,
generator=generator,
use_multinomial=use_multinomial,
)
batch_size, block_length = sample.shape
active_block = sample == mask_token_id
masks_remaining = active_block.any()
if isinstance(timestep, torch.Tensor):
step_index = int(timestep.item())
else:
step_index = int(timestep)
# --- Mask-filling transfer ---
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if masks_remaining and self._transfer_schedule is not None:
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
confidence = torch.where(
active_block,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
for b in range(batch_size):
high_conf = confidence[b] > threshold
if high_conf.sum().item() >= num_to_transfer:
transfer_index[b] = high_conf
else:
k = min(num_to_transfer, int(active_block[b].sum().item()))
if k > 0:
_, idx = torch.topk(confidence[b], k=k)
transfer_index[b, idx] = True
# --- Editing transfer (non-mask, non-prompt positions) ---
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if editing_enabled:
if prompt_mask is None:
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
editing_conf = torch.where(
editable,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
high_conf_edit = editing_conf > float(editing_threshold)
token_changed = sampled_tokens != sample
editing_transfer_index = high_conf_edit & token_changed & editable
# Apply transfers
final_transfer = transfer_index | editing_transfer_index
prev_sample = sample.clone()
if final_transfer.any():
prev_sample[final_transfer] = sampled_tokens[final_transfer]
if not return_dict:
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
return BlockRefinementSchedulerOutput(
prev_sample=prev_sample,
transfer_index=transfer_index,
editing_transfer_index=editing_transfer_index,
sampled_tokens=sampled_tokens,
sampled_probs=sampled_probs,
)
@staticmethod
def check_eos_finished(
cur_x: torch.LongTensor,
sampled_tokens: torch.LongTensor,
final_transfer: torch.BoolTensor,
finished: torch.BoolTensor,
eos_token_id: int,
mask_token_id: int,
prompt_length: int,
) -> torch.BoolTensor:
"""
Update per-batch finished flags when EOS tokens are committed.
Args:
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Current full sequence including all blocks up to the current window.
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Tokens sampled by the scheduler in this step.
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Combined mask of committed and edited positions.
finished (`torch.BoolTensor` of shape `(batch_size,)`):
Current per-batch finished flags.
eos_token_id (`int`):
EOS token ID.
mask_token_id (`int`):
Mask token ID.
prompt_length (`int`):
Number of prompt tokens at the start of the sequence.
Returns:
`torch.BoolTensor`: Updated finished flags.
"""
batch_size = cur_x.shape[0]
for b in range(batch_size):
if finished[b]:
continue
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
if not eos_in_commits:
continue
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
finished[b] = True
return finished
def check_block_should_continue(
self,
step_idx: int,
masks_remaining: bool,
editing_enabled: bool,
editing_transfer_index: torch.BoolTensor,
post_steps: int,
max_post_steps: int,
finished: torch.BoolTensor,
) -> bool:
"""
Determine whether the inner refinement loop should continue for the current block.
Args:
step_idx (`int`):
Current refinement step index within this block.
masks_remaining (`bool`):
Whether any mask tokens remain in the block.
editing_enabled (`bool`):
Whether editing mode is active.
editing_transfer_index (`torch.BoolTensor`):
Which tokens were edited in this step.
post_steps (`int`):
Number of post-mask editing steps taken so far.
max_post_steps (`int`):
Maximum allowed post-mask editing steps.
finished (`torch.BoolTensor`):
Per-batch finished flags (from EOS detection).
Returns:
`bool`: `True` if refinement should continue, `False` to break.
"""
if finished.all():
return False
if not masks_remaining and not editing_enabled:
return False
if not masks_remaining and not editing_transfer_index.any():
return False
if masks_remaining and step_idx >= self.num_inference_steps:
return False
if not masks_remaining and post_steps > max_post_steps:
return False
return True
def add_noise(
self,
original_samples: torch.LongTensor,
attention_mask: torch.LongTensor,
*,
prompt_length: int,
block_length: int,
mask_token_id: int,
generator: torch.Generator | None = None,
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
"""
Apply the forward (noising) process for semi-autoregressive block masking.
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
one are the unmasked positions in the other.
Args:
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Clean token IDs.
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Padding mask (1 for valid, 0 for padding).
prompt_length (`int`):
Number of leading prompt tokens to keep unmasked.
block_length (`int`):
Block size for masking.
mask_token_id (`int`):
Token ID to use for masked positions.
generator (`torch.Generator`, *optional*):
RNG for reproducibility.
Returns:
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
corresponding boolean masks.
"""
batch_size, seq_len = original_samples.shape
device = original_samples.device
noisy = original_samples.clone()
noisy_rev = original_samples.clone()
masked = torch.zeros_like(original_samples, dtype=torch.bool)
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
valid = attention_mask.to(dtype=torch.bool)
for block_start in range(prompt_length, seq_len, block_length):
block_end = min(seq_len, block_start + block_length)
seg_len = block_end - block_start
if seg_len <= 0:
continue
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
seg = seg & valid[:, block_start:block_end]
seg_rev = (~seg) & valid[:, block_start:block_end]
masked[:, block_start:block_end] = seg
masked_rev[:, block_start:block_end] = seg_rev
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
return noisy, noisy_rev, masked, masked_rev
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]

View File

@@ -11,7 +11,6 @@ from typing import Any, Iterable
import numpy as np
import torch
import torch.nn.functional as F
if getattr(torch, "distributed", None) is not None:
@@ -110,92 +109,6 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def compute_confidence_aware_loss(
logits: torch.Tensor,
labels: torch.Tensor,
*,
lambda_conf: float = 0.0,
temperature: float = 1.0,
per_token_weights: torch.Tensor | None = None,
ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes a confidence-aware training loss for token classification-style heads.
This loss combines:
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
Args:
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
are excluded from both losses.
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
sharpen the distribution and change the strength of the confidence gradients.
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
"""
if logits.ndim < 2:
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
if labels.shape != logits.shape[:-1]:
raise ValueError(
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
)
if temperature <= 0:
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
valid = labels.ne(ignore_index)
if per_token_weights is None:
weights = torch.ones_like(labels, dtype=logits.dtype)
else:
if per_token_weights.shape != labels.shape:
raise ValueError(
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
)
weights = per_token_weights.to(dtype=logits.dtype)
# Supervised CE (optionally weighted).
vocab_size = logits.shape[-1]
per_token_nll = F.cross_entropy(
logits.reshape(-1, vocab_size),
labels.reshape(-1),
reduction="none",
ignore_index=ignore_index,
).reshape_as(labels)
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
# Confidence loss: penalize entropy only where prediction is already correct.
if lambda_conf == 0.0:
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
return loss_sft, loss_sft, loss_conf
with torch.no_grad():
pred = logits.argmax(dim=-1)
correct = valid & pred.eq(labels)
scaled_logits = logits.float()
if temperature != 1.0:
scaled_logits = scaled_logits / float(temperature)
probs = torch.softmax(scaled_logits, dim=-1)
eps = torch.finfo(probs.dtype).tiny
log_probs = torch.log(probs.clamp_min(eps))
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
loss = loss_sft + float(lambda_conf) * loss_conf
return loss, loss_sft, loss_conf
def resolve_interpolation_mode(interpolation_type: str):
"""
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The

View File

@@ -521,36 +521,6 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]
@@ -2518,36 +2488,6 @@ class AmusedScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlockRefinementScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CMStochasticIterativeScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -2222,36 +2222,6 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2PipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LongCatImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -29,7 +29,6 @@ from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -68,11 +67,9 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
deprecate(
"diffusers.utils.testing_utils",
"1.0.0",
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version

View File

@@ -19,16 +19,11 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -338,23 +333,5 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
if is_torch_version("<", "2.7.0"):
return cached
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()

View File

@@ -1,73 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAE
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAE
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_config(self):
return {
"in_channels": 3,
"channels": 32,
"num_enc_blocks": 1,
"num_dec_blocks": 1,
"z_channels": 4,
"double_z": True,
"ch_mult": (1, 2),
"sample_size": 32,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAEEncoder2D",
"KVAEDecoder2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -1,118 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAEVideo
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_video_config(self):
return {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 1,
"in_channels": 3,
"out_ch": 3,
"z_channels": 4,
"temporal_compress_times": 2,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": video}
@property
def input_shape(self):
return (3, 3, 16, 16)
@property
def output_shape(self):
return (3, 3, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass
def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_training(self):
self._run_nondeterministic(super().test_training)
def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
@unittest.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
)
def test_effective_gradient_checkpointing(self):
pass
def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)

View File

@@ -481,8 +481,6 @@ class LoraHotSwappingForModelTesterMixin:
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -490,31 +488,21 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
finally:
diffusers_logging.disable_propagation()
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -530,26 +518,20 @@ class LoraHotSwappingForModelTesterMixin:
# check the error and log
import logging
from diffusers.utils import logging as diffusers_logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
diffusers_logging.enable_propagation()
try:
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")

View File

@@ -22,7 +22,6 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from ...testing_utils import (
is_context_parallel,
@@ -161,21 +160,16 @@ def _custom_mesh_worker(
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -200,10 +194,6 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
@@ -219,11 +209,6 @@ class ContextParallelTesterMixin:
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}

View File

@@ -13,31 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import unittest
import torch
from diffusers import BriaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
def create_bria_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
@@ -58,8 +50,11 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
@@ -78,36 +73,53 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
)
del sd
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class BriaTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return BriaTransformer2DModel
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
@property
def model_split_percents(self) -> list:
return [0.8, 0.7, 0.7]
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def output_shape(self) -> tuple:
return (16, 4)
@property
def input_shape(self) -> tuple:
return (16, 4)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -119,35 +131,11 @@ class BriaTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [0, 4, 4],
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
inputs_dict = self.dummy_input
return init_dict, inputs_dict
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -155,6 +143,7 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
@@ -167,59 +156,26 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestBriaTransformerCompile(BriaTransformerTesterConfig, TorchCompileTesterMixin):
pass
class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
class TestBriaTransformerIPAdapter(BriaTransformerTesterConfig, IPAdapterTesterMixin):
@property
def ip_adapter_processor_cls(self):
return FluxIPAdapterJointAttnProcessor2_0
class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = BriaTransformer2DModel
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_bria_ip_adapter_state_dict(model)
class TestBriaTransformerLoRA(BriaTransformerTesterConfig, LoraTesterMixin):
pass
class TestBriaTransformerLoRAHotSwap(BriaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 32
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
def prepare_init_args_and_inputs_for_common(self):
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()

View File

@@ -13,50 +13,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import BriaFiboTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return BriaFiboTransformer2DModel
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = BriaFiboTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.8, 0.7, 0.7]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 1
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
@property
def model_split_percents(self) -> list:
return [0.8, 0.7, 0.7]
@property
def output_shape(self) -> tuple:
return (256, 48)
@property
def input_shape(self) -> tuple:
def input_shape(self):
return (16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def output_shape(self):
return (256, 48)
def get_init_dict(self) -> dict:
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 48,
"num_layers": 1,
@@ -69,41 +81,9 @@ class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [0, 4, 4],
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 48
num_image_channels = 3
height = width = 16
sequence_length = 32
embedding_dim = 64
inputs_dict = self.dummy_input
return init_dict, inputs_dict
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": encoder_hidden_states,
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
}
class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin):
pass
class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"BriaFiboTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestBriaFiboTransformerCompile(BriaFiboTransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -150,7 +150,8 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3

View File

@@ -90,7 +90,8 @@ class Flux2TransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32

View File

@@ -14,7 +14,6 @@
import warnings
import pytest
import torch
from diffusers import QwenImageTransformer2DModel
@@ -78,7 +77,8 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
@@ -106,10 +106,9 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_infers_text_seq_len_from_mask(self, batch_size):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -123,7 +122,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2 * batch_size
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -140,7 +139,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5 * batch_size
assert normalized_mask2.sum().item() == 5
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
@@ -150,10 +149,9 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert per_sample_len_none is None
assert normalized_mask_none is None
@pytest.mark.parametrize("batch_size", [1, 2])
def test_non_contiguous_attention_mask(self, batch_size):
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -286,14 +284,6 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_linear(self):
super().test_hotswapping_compiled_model_linear()
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_both_linear_and_other(self):
super().test_hotswapping_compiled_model_both_linear_and_other()
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.9]
class SanaTransformer2DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaTransformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
def output_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.9]
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"out_channels": 4,
@@ -75,9 +77,53 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 8,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaTransformer2D(SanaTransformer2DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Transformer 2D."""
class TestSanaTransformer2DMemory(SanaTransformer2DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTraining(SanaTransformer2DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Transformer 2D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestSanaTransformer2DAttention(SanaTransformer2DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Transformer 2D."""
class TestSanaTransformer2DCompile(SanaTransformer2DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Transformer 2D."""
class TestSanaTransformer2DBitsAndBytes(SanaTransformer2DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTorchAo(SanaTransformer2DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Transformer 2D."""

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class SanaVideoTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | float | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 2, 16, 16)
@property
def output_shape(self):
return (16, 2, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 2,
@@ -82,16 +80,56 @@ class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaVideoTransformer3D(SanaVideoTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DMemory(SanaVideoTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTraining(SanaVideoTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Video Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
class TestSanaVideoTransformer3DAttention(SanaVideoTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Video Transformer 3D."""
def prepare_init_args_and_inputs_for_common(self):
return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestSanaVideoTransformer3DCompile(SanaVideoTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DBitsAndBytes(SanaVideoTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTorchAo(SanaVideoTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Video Transformer 3D."""

View File

@@ -18,7 +18,7 @@ import unittest
import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
from diffusers.training_utils import set_seed
from ..testing_utils import slow
@@ -85,47 +85,3 @@ class TrainingTests(unittest.TestCase):
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
def test_confidence_aware_loss(self):
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
labels = torch.tensor([[0, 0]])
weights = torch.tensor([[1.0, 2.0]])
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=0.0, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss, loss_sft))
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
lambda_conf = 0.25
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
)
# Manual expected values for the small 2-class case.
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
1, 2
)
expected_sft = (per_token_nll * weights).sum() / weights.sum()
pred = logits.argmax(dim=-1)
correct = pred.eq(labels)
log_probs = torch.log_softmax(logits.float(), dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
weights * correct.to(weights.dtype)
).sum().clamp_min(1)
expected = expected_sft + lambda_conf * expected_conf
self.assertTrue(torch.allclose(loss_sft, expected_sft))
self.assertTrue(torch.allclose(loss_conf, expected_conf))
self.assertTrue(torch.allclose(loss, expected))
# Temperature affects only the confidence term.
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))

View File

@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import unittest
import warnings
import pytest
@@ -184,25 +182,6 @@ class DeprecateTester(unittest.TestCase):
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename
def test_deprecate_testing_utils_module(self):
import diffusers.utils.testing_utils
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
importlib.reload(diffusers.utils.testing_utils)
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
messages = [str(w.message) for w in deprecation_warnings]
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
)
assert any(
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
for msg in messages
), f"Expected deprecation message substring not found, got: {messages}"
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):

View File

@@ -1,245 +0,0 @@
import unittest
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
class _DummyModelOutput:
def __init__(self, logits):
self.logits = logits
class _DummyCausalLM(torch.nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = int(vocab_size)
self.register_buffer("_device_anchor", torch.empty(0))
@property
def dtype(self):
return torch.float32
@property
def device(self):
return self._device_anchor.device
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
batch_size, seq_len = input_ids.shape
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
# Make confidence vary with token position so top-k commits are deterministic.
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
return _DummyModelOutput(logits=logits)
def _make_pipeline(tokenizer=None):
model = _DummyCausalLM(vocab_size=32)
scheduler = BlockRefinementScheduler()
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
class LLaDA2PipelineTest(unittest.TestCase):
def test_pipeline_runs(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
out = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=24,
block_length=8,
num_inference_steps=8,
temperature=0.0,
threshold=2.0, # force top-k commits
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
eos_token_id=None,
output_type="seq",
)
self.assertEqual(out.sequences.shape, (2, 24))
self.assertFalse((out.sequences == 31).any().item())
def test_pipeline_return_tuple(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
sequences, texts = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
return_dict=False,
)
self.assertEqual(sequences.shape, (1, 16))
self.assertIsNone(texts)
def test_output_type_seq(self):
"""output_type='seq' should return sequences but no texts."""
pipe = _make_pipeline().to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
)
self.assertIsNotNone(out.sequences)
self.assertEqual(out.sequences.shape, (1, 16))
self.assertIsNone(out.texts)
def test_output_type_text_without_tokenizer(self):
"""output_type='text' without a tokenizer should return texts=None."""
pipe = _make_pipeline(tokenizer=None).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNone(out.texts)
def test_output_type_text_with_tokenizer(self):
"""output_type='text' with a tokenizer should return decoded texts."""
tok = type(
"Tok",
(),
{
"eos_token_id": None,
"mask_token_id": 31,
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
},
)()
pipe = _make_pipeline(tokenizer=tok).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNotNone(out.texts)
self.assertEqual(len(out.texts), 1)
self.assertTrue(out.texts[0].startswith("decoded_"))
def test_output_type_invalid_raises(self):
"""Invalid output_type should raise ValueError."""
pipe = _make_pipeline().to("cpu")
with self.assertRaises(ValueError):
pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
mask_token_id=31,
output_type="invalid",
)
def test_prepare_input_ids_from_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertTrue(torch.equal(result, ids))
def test_prepare_input_ids_from_1d_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([1, 2, 3], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertEqual(result.shape, (1, 3))
def test_prepare_input_ids_no_tokenizer_raises(self):
pipe = _make_pipeline(tokenizer=None)
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
pipe = _make_pipeline()
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=[{"role": "user", "content": "hi"}],
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_neither_raises(self):
pipe = _make_pipeline()
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1534,18 +1534,14 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1556,11 +1552,11 @@ class PipelineTesterMixin:
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

View File

@@ -1,470 +0,0 @@
import tempfile
import unittest
import torch
from diffusers import BlockRefinementScheduler
class BlockRefinementSchedulerTest(unittest.TestCase):
def get_scheduler(self, **kwargs):
config = {
"block_length": 32,
"num_inference_steps": 8,
"threshold": 0.95,
"editing_threshold": None,
"minimal_topk": 1,
}
config.update(kwargs)
return BlockRefinementScheduler(**config)
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
"""Create logits where softmax of the target token has approximately the given probability."""
batch_size, block_length = target_probs.shape
logits = torch.zeros(batch_size, block_length, vocab_size)
# Set token 0 as the "predicted" token with a logit proportional to desired probability
for b in range(batch_size):
for t in range(block_length):
p = target_probs[b, t].item()
if p > 0:
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
return logits
def test_set_timesteps(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
self.assertEqual(scheduler.num_inference_steps, 8)
self.assertEqual(len(scheduler.timesteps), 8)
self.assertEqual(scheduler.timesteps[0].item(), 7)
self.assertEqual(scheduler.timesteps[-1].item(), 0)
def test_set_timesteps_invalid(self):
scheduler = self.get_scheduler()
with self.assertRaises(ValueError):
scheduler.set_timesteps(0)
def test_get_num_transfer_tokens_even(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
self.assertEqual(schedule.sum().item(), 32)
self.assertEqual(len(schedule), 8)
self.assertTrue((schedule == 4).all().item())
def test_get_num_transfer_tokens_remainder(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
self.assertEqual(schedule.sum().item(), 10)
self.assertEqual(len(schedule), 3)
self.assertEqual(schedule[0].item(), 4)
self.assertEqual(schedule[1].item(), 3)
self.assertEqual(schedule[2].item(), 3)
def test_transfer_schedule_created_on_set_timesteps(self):
scheduler = self.get_scheduler(block_length=16)
scheduler.set_timesteps(4)
self.assertIsNotNone(scheduler._transfer_schedule)
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
def test_save_load_config_round_trip(self):
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
with tempfile.TemporaryDirectory() as tmpdir:
scheduler.save_config(tmpdir)
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
self.assertEqual(loaded.config.block_length, 64)
self.assertEqual(loaded.config.threshold, 0.8)
self.assertEqual(loaded.config.editing_threshold, 0.5)
self.assertEqual(loaded.config.minimal_topk, 2)
def test_from_config(self):
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
self.assertEqual(new_scheduler.config.block_length, 16)
self.assertEqual(new_scheduler.config.threshold, 0.7)
def test_step_commits_tokens(self):
"""Verify that step() commits mask tokens based on confidence."""
scheduler = self.get_scheduler(block_length=8)
scheduler.set_timesteps(2)
batch_size, block_length, vocab_size = 1, 8, 32
mask_id = 31
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
# Create logits where confidence decreases with position
logits = torch.zeros(batch_size, block_length, vocab_size)
for i in range(block_length):
logits[0, i, i] = 10.0 - i # decreasing confidence
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
threshold=0.95,
return_dict=True,
)
# With 8 tokens and 2 steps, first step should commit 4 tokens
committed = out.transfer_index[0].sum().item()
self.assertEqual(committed, 4)
def test_step_no_editing_by_default(self):
"""Without editing_threshold, no non-mask tokens should be changed."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 10.0 # predict token 15 for all positions
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=None,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index.any().item())
self.assertFalse(out.transfer_index[0, 0].item())
self.assertFalse(out.transfer_index[0, 1].item())
def test_step_editing_replaces_tokens(self):
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
# Token 0: predict 50 (different from 10) with very high logit
logits[0, 0, 15] = 20.0
# Token 1: predict 20 (same as current)
logits[0, 1, 20] = 20.0
# Mask tokens
logits[0, 2, 5] = 5.0
logits[0, 3, 6] = 5.0
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
return_dict=True,
)
# Token 0 should be edited (different prediction, high confidence)
self.assertTrue(out.editing_transfer_index[0, 0].item())
# Token 1 should NOT be edited (same prediction)
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_prompt_mask_prevents_editing(self):
"""Prompt positions should never be edited even with editing enabled."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 20.0
prompt_mask = torch.tensor([True, True, False, False])
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
prompt_mask=prompt_mask,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index[0, 0].item())
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_return_tuple(self):
"""Verify tuple output when return_dict=False."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.full((1, 4), 31, dtype=torch.long)
logits = torch.randn(1, 4, vocab_size)
result = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
return_dict=False,
)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 5)
def test_step_batched(self):
"""Verify step works with batch_size > 1."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
batch_size, vocab_size = 3, 32
mask_id = 31
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
logits = torch.randn(batch_size, 4, vocab_size)
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
return_dict=True,
)
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
def test_check_block_should_continue_finished(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([True, True])
result = scheduler.check_block_should_continue(
step_idx=0,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_no_masks_no_edits(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=5,
masks_remaining=False,
editing_enabled=True,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=1,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_steps_exhausted(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=8,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_eos_finished_marks_batch(self):
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, token, eos, mask, mask]
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
final_transfer = torch.tensor([[False, False, False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertTrue(finished[0].item())
def test_check_eos_finished_ignores_when_masks_before_eos(self):
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertFalse(finished[0].item())
def test_check_eos_finished_already_finished(self):
"""Already-finished batches should stay finished."""
mask_id, eos_id = 99, 2
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, False]])
finished = torch.tensor([True])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=2,
)
self.assertTrue(finished[0].item())
def test_add_noise(self):
scheduler = self.get_scheduler(block_length=4)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
mask_token_id = 99
gen = torch.Generator().manual_seed(42)
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=2,
block_length=4,
mask_token_id=mask_token_id,
generator=gen,
)
# Prompt positions should never be masked
self.assertFalse(masked[0, 0].item())
self.assertFalse(masked[0, 1].item())
self.assertFalse(masked_rev[0, 0].item())
self.assertFalse(masked_rev[0, 1].item())
# Noisy should have mask_token_id where masked is True
self.assertTrue((noisy[masked] == mask_token_id).all().item())
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
# masked and masked_rev should be complementary within valid non-prompt positions
non_prompt = torch.zeros_like(masked)
non_prompt[0, 2:] = True
combined = masked | masked_rev
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
class TestTopPFiltering(unittest.TestCase):
def test_top_p_filtering(self):
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
def test_top_p_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
self.assertTrue(torch.equal(result, logits))
def test_top_p_filtering_one(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
self.assertTrue(torch.equal(result, logits))
class TestTopKFiltering(unittest.TestCase):
def test_top_k_filtering(self):
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
def test_top_k_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_zero(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_large_k(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
self.assertTrue(torch.equal(result, logits))
class TestSampleFromLogits(unittest.TestCase):
def test_greedy_sampling(self):
logits = torch.tensor([[1.0, 5.0, 2.0]])
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 1)
self.assertEqual(tokens.shape, (1,))
self.assertEqual(probs.shape, (1,))
def test_multinomial_sampling(self):
logits = torch.tensor([[0.0, 100.0, -100.0]])
gen = torch.Generator().manual_seed(42)
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=1.0,
top_k=None,
top_p=None,
generator=gen,
use_multinomial=True,
)
self.assertEqual(tokens.item(), 1)
def test_temperature_scaling(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
tokens, _ = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.01,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 2)
def test_negative_temperature_raises(self):
logits = torch.tensor([[1.0, 2.0]])
with self.assertRaises(ValueError):
BlockRefinementScheduler._sample_from_logits(
logits,
temperature=-1.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
def fetch_pipeline_objects():
models = api.list_models(filter="diffusers")
models = api.list_models(library="diffusers")
downloads = defaultdict(int)
for model in models: