mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-27 02:47:41 +08:00
Compare commits
23 Commits
release-wo
...
sd3-test-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ec4dee783 | ||
|
|
50015c966a | ||
|
|
762ae059fa | ||
|
|
5d207e756e | ||
|
|
e358ddcce6 | ||
|
|
153fcbc5a8 | ||
|
|
da6718f080 | ||
|
|
832676d35e | ||
|
|
7bbd96da5d | ||
|
|
62777fa819 | ||
|
|
f1fd515257 | ||
|
|
afdda57f61 | ||
|
|
5fc2bd2c8f | ||
|
|
6350a7690a | ||
|
|
9d4c9dcf21 | ||
|
|
ef309a1bb0 | ||
|
|
b9761ce5a2 | ||
|
|
52558b45d8 | ||
|
|
c02c17c6ee | ||
|
|
a9855c4204 | ||
|
|
0b35834351 | ||
|
|
522b523e40 | ||
|
|
e9b9f25f67 |
8
.github/workflows/release_tests_fast.yml
vendored
8
.github/workflows/release_tests_fast.yml
vendored
@@ -4,6 +4,7 @@
|
||||
name: (Release) Fast GPU Tests on main
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- "v*.*.*-release"
|
||||
@@ -33,6 +34,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -74,6 +76,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -125,6 +128,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -175,6 +179,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -232,6 +237,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -274,6 +280,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -316,6 +323,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -446,6 +446,10 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoder_kl_hunyuan_video15
|
||||
title: AutoencoderKLHunyuanVideo15
|
||||
- local: api/models/autoencoder_kl_kvae
|
||||
title: AutoencoderKLKVAE
|
||||
- local: api/models/autoencoder_kl_kvae_video
|
||||
title: AutoencoderKLKVAEVideo
|
||||
- local: api/models/autoencoderkl_audio_ltx_2
|
||||
title: AutoencoderKLLTX2Audio
|
||||
- local: api/models/autoencoderkl_ltx_2
|
||||
@@ -666,6 +670,10 @@
|
||||
- local: api/pipelines/z_image
|
||||
title: Z-Image
|
||||
title: Image
|
||||
- sections:
|
||||
- local: api/pipelines/llada2
|
||||
title: LLaDA2
|
||||
title: Text
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
@@ -714,6 +722,8 @@
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/block_refinement
|
||||
title: BlockRefinementScheduler
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
title: CMStochasticIterativeScheduler
|
||||
- local: api/schedulers/ddim_cogvideox
|
||||
|
||||
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAE
|
||||
|
||||
The 2D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAE
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAE
|
||||
- decode
|
||||
- all
|
||||
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAEVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAEVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAEVideo
|
||||
- decode
|
||||
- all
|
||||
|
||||
90
docs/source/en/api/pipelines/llada2.md
Normal file
90
docs/source/en/api/pipelines/llada2.md
Normal file
@@ -0,0 +1,90 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
|
||||
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
|
||||
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
|
||||
steps.
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
|
||||
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
output = pipe(
|
||||
prompt="Write a short poem about the ocean.",
|
||||
gen_length=256,
|
||||
block_length=32,
|
||||
num_inference_steps=32,
|
||||
threshold=0.7,
|
||||
editing_threshold=0.5,
|
||||
max_post_steps=16,
|
||||
temperature=0.0,
|
||||
)
|
||||
print(output.texts[0])
|
||||
```
|
||||
|
||||
## Callbacks
|
||||
|
||||
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
|
||||
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
|
||||
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
|
||||
window.
|
||||
|
||||
```py
|
||||
def on_step_end(pipe, step, timestep, callback_kwargs):
|
||||
block_x = callback_kwargs["block_x"]
|
||||
# Inspect or modify `block_x` here.
|
||||
return {"block_x": block_x}
|
||||
|
||||
out = pipe(
|
||||
prompt="Write a short poem.",
|
||||
callback_on_step_end=on_step_end,
|
||||
callback_on_step_end_tensor_inputs=["block_x"],
|
||||
)
|
||||
```
|
||||
|
||||
## Recommended parameters
|
||||
|
||||
LLaDA2.1 models support two modes:
|
||||
|
||||
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|
||||
|------|-------------|---------------------|------------------|
|
||||
| Quality | 0.7 | 0.5 | 16 |
|
||||
| Speed | 0.5 | `None` | 16 |
|
||||
|
||||
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
|
||||
|
||||
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
|
||||
|
||||
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
|
||||
|
||||
## LLaDA2Pipeline
|
||||
[[autodoc]] LLaDA2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LLaDA2PipelineOutput
|
||||
[[autodoc]] pipelines.LLaDA2PipelineOutput
|
||||
@@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [Latte](latte) | text2image |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [LLaDA2](llada2) | text2text |
|
||||
| [Lumina-T2X](lumina) | text2image |
|
||||
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
|
||||
25
docs/source/en/api/schedulers/block_refinement.md
Normal file
25
docs/source/en/api/schedulers/block_refinement.md
Normal file
@@ -0,0 +1,25 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BlockRefinementScheduler
|
||||
|
||||
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
|
||||
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
|
||||
token with high confidence.
|
||||
|
||||
This scheduler is used by [`LLaDA2Pipeline`].
|
||||
|
||||
## BlockRefinementScheduler
|
||||
[[autodoc]] BlockRefinementScheduler
|
||||
|
||||
## BlockRefinementSchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput
|
||||
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
50
examples/discrete_diffusion/README.md
Normal file
50
examples/discrete_diffusion/README.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Discrete Token Diffusion (Experimental)
|
||||
|
||||
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
|
||||
|
||||
## LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
|
||||
|
||||
### Train
|
||||
|
||||
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name wikitext \
|
||||
--dataset_config_name wikitext-2-raw-v1 \
|
||||
--text_column text \
|
||||
--output_dir llada2-output \
|
||||
--max_train_steps 1000 \
|
||||
--prompt_length 32 \
|
||||
--block_length 32 \
|
||||
--lambda_conf 2.0 \
|
||||
--conf_temperature 0.5
|
||||
```
|
||||
|
||||
If you don't want to download a dataset, you can use random-token data:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--output_dir llada2-output \
|
||||
--use_dummy_data \
|
||||
--num_dummy_samples 2048
|
||||
```
|
||||
|
||||
### Sample
|
||||
|
||||
```bash
|
||||
python examples/discrete_diffusion/sample_llada2.py \
|
||||
--model_id inclusionAI/LLaDA2.1-mini \
|
||||
--prompt "Write a short poem about the ocean." \
|
||||
--gen_length 256 \
|
||||
--num_inference_steps 32 \
|
||||
--threshold 0.7 \
|
||||
--editing_threshold 0.5 \
|
||||
--max_post_steps 16 \
|
||||
--use_chat_template \
|
||||
--add_generation_prompt
|
||||
```
|
||||
263
examples/discrete_diffusion/sample_llada2.py
Normal file
263
examples/discrete_diffusion/sample_llada2.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Sample script for LLaDA2-style discrete diffusion text generation.
|
||||
|
||||
This script demonstrates how to use the LLaDA2Pipeline for text generation
|
||||
using block-wise iterative refinement.
|
||||
|
||||
Example usage:
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
default="inclusionAI/LLaDA2.0-mini",
|
||||
help="HuggingFace model ID or path to local model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="Why does Camus think that Sisyphus is happy?",
|
||||
help="Text prompt to generate from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_length",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of tokens to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_length",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Size of each generation block.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_inference_steps",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of refinement steps per block.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Sampling temperature (0.0 for greedy).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Nucleus sampling probability threshold.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top-k sampling parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Confidence threshold for committing tokens.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--editing_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_post_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling_method",
|
||||
type=str,
|
||||
default="multinomial",
|
||||
choices=["auto", "greedy", "multinomial"],
|
||||
help="Sampling method for block refinement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos_early_stop",
|
||||
action="store_true",
|
||||
help="Stop generation early when EOS token is generated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_chat_template",
|
||||
action="store_true",
|
||||
help="Use the tokenizer chat template for the prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_generation_prompt",
|
||||
action="store_true",
|
||||
help="Add the generation prompt when using the chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to run inference on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["float32", "float16", "bfloat16"],
|
||||
help="Model dtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Random seed for reproducibility.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["group", "sequential"],
|
||||
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse dtype
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
torch_dtype = dtype_map[args.dtype]
|
||||
|
||||
print(f"Loading model: {args.model_id}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
|
||||
|
||||
# Load model with appropriate memory settings based on offload strategy
|
||||
if args.offload == "group":
|
||||
# For group offloading, load to CPU first then apply hooks
|
||||
print("Using group offloading for memory efficiency...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
# Apply group offloading with CUDA streams for better performance
|
||||
onload_device = torch.device(args.device)
|
||||
offload_device = torch.device("cpu")
|
||||
apply_group_offloading(
|
||||
model,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
)
|
||||
elif args.offload == "sequential":
|
||||
# For sequential offloading, load to CPU first
|
||||
print("Using sequential CPU offloading (slower but lower memory)...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
# Sequential offloading will be applied via pipeline
|
||||
else:
|
||||
# Default: use device_map="auto" for automatic memory management
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
trust_remote_code=True,
|
||||
dtype=torch_dtype,
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
revision=args.revision,
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Create pipeline
|
||||
scheduler = BlockRefinementScheduler()
|
||||
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
|
||||
# Apply sequential CPU offload if requested
|
||||
if args.offload == "sequential":
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
# Set up generator for reproducibility
|
||||
generator = None
|
||||
if args.seed is not None:
|
||||
generator = torch.Generator(device=args.device).manual_seed(args.seed)
|
||||
|
||||
print(f"\nPrompt: {args.prompt}")
|
||||
print(
|
||||
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
|
||||
)
|
||||
print("-" * 50)
|
||||
|
||||
# Generate
|
||||
output = pipe(
|
||||
prompt=args.prompt,
|
||||
use_chat_template=args.use_chat_template,
|
||||
add_generation_prompt=args.add_generation_prompt,
|
||||
gen_length=args.gen_length,
|
||||
block_length=args.block_length,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
threshold=args.threshold,
|
||||
editing_threshold=args.editing_threshold,
|
||||
max_post_steps=args.max_post_steps,
|
||||
sampling_method=args.sampling_method,
|
||||
eos_early_stop=args.eos_early_stop,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
print("\nGenerated text:")
|
||||
print(output.texts[0])
|
||||
|
||||
print(f"\nGenerated {output.sequences.shape[1]} tokens")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
321
examples/discrete_diffusion/train_llada2.py
Normal file
321
examples/discrete_diffusion/train_llada2.py
Normal file
@@ -0,0 +1,321 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
from diffusers.training_utils import compute_confidence_aware_loss
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
model_name_or_path: str
|
||||
dataset_name: str
|
||||
dataset_config_name: Optional[str]
|
||||
text_column: str
|
||||
cache_dir: Optional[str]
|
||||
use_dummy_data: bool
|
||||
num_dummy_samples: int
|
||||
|
||||
output_dir: str
|
||||
seed: int
|
||||
max_train_steps: int
|
||||
checkpointing_steps: int
|
||||
logging_steps: int
|
||||
|
||||
per_device_train_batch_size: int
|
||||
gradient_accumulation_steps: int
|
||||
learning_rate: float
|
||||
weight_decay: float
|
||||
lr_scheduler: str
|
||||
lr_warmup_steps: int
|
||||
|
||||
max_length: int
|
||||
prompt_length: int
|
||||
block_length: int
|
||||
|
||||
lambda_conf: float
|
||||
conf_temperature: float
|
||||
|
||||
|
||||
def parse_args() -> TrainConfig:
|
||||
parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
|
||||
|
||||
parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
|
||||
parser.add_argument("--dataset_name", type=str, default="wikitext")
|
||||
parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
|
||||
parser.add_argument("--text_column", type=str, default="text")
|
||||
parser.add_argument("--cache_dir", type=str, default=None)
|
||||
parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
|
||||
parser.add_argument("--num_dummy_samples", type=int, default=2048)
|
||||
|
||||
parser.add_argument("--output_dir", type=str, default="block-refinement-output")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--max_train_steps", type=int, default=1000)
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=500)
|
||||
parser.add_argument("--logging_steps", type=int, default=50)
|
||||
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
|
||||
)
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=100)
|
||||
|
||||
parser.add_argument("--max_length", type=int, default=256)
|
||||
parser.add_argument("--prompt_length", type=int, default=32)
|
||||
parser.add_argument("--block_length", type=int, default=32)
|
||||
|
||||
parser.add_argument("--lambda_conf", type=float, default=2.0)
|
||||
parser.add_argument("--conf_temperature", type=float, default=0.5)
|
||||
|
||||
args = parser.parse_args()
|
||||
return TrainConfig(**vars(args))
|
||||
|
||||
|
||||
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
|
||||
texts = examples[text_column]
|
||||
texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
|
||||
return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
|
||||
|
||||
|
||||
class RandomTokenDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
|
||||
self.num_samples = int(num_samples)
|
||||
self.seq_len = int(seq_len)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.pad_token_id = int(pad_token_id)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
del idx
|
||||
input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
def main():
    """Train a masked-diffusion language model with block refinement.

    Pipeline: parse config -> set up Accelerate -> load tokenizer/model ->
    build dataloader (dummy or HF dataset) -> optimizer/scheduler -> training
    loop with forward + reversed masking passes and a confidence-aware loss ->
    periodic checkpointing -> final save.
    """
    cfg = parse_args()
    # Validate the masking geometry up front: the prompt must leave room for
    # at least one maskable token, and blocks must be non-empty.
    if cfg.prompt_length >= cfg.max_length:
        raise ValueError("`prompt_length` must be < `max_length`.")
    if cfg.block_length <= 0:
        raise ValueError("`block_length` must be > 0.")

    project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        project_config=project_config,
    )
    # Only rank 0 creates the output dir; everyone else waits so subsequent
    # writes cannot race the mkdir.
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    set_seed(cfg.seed)
    logger.info("Training configuration: %s", asdict(cfg))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
    # Causal-LM tokenizers often lack a pad token; reuse EOS so the collator
    # can pad batches.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The diffusion objective needs a mask token; add one if the tokenizer has
    # none (embeddings are resized below to cover it).
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    # bf16 on GPU, fp32 on CPU (bf16 matmuls on CPU are slow/unsupported).
    load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
    model.resize_token_embeddings(len(tokenizer))
    # NOTE(review): this cast is a no-op when load_dtype is already fp32 —
    # presumably defensive after resize_token_embeddings; confirm if needed.
    if load_dtype == torch.float32:
        model.to(dtype=torch.float32)

    mask_token_id = int(tokenizer.mask_token_id)

    if cfg.use_dummy_data:
        # Synthetic random tokens — lets the whole loop run without a dataset.
        dataset = RandomTokenDataset(
            num_samples=cfg.num_dummy_samples,
            seq_len=cfg.max_length,
            vocab_size=len(tokenizer),
            pad_token_id=int(tokenizer.pad_token_id),
        )
        train_dataloader = DataLoader(
            dataset,
            shuffle=True,
            batch_size=cfg.per_device_train_batch_size,
            drop_last=True,
        )
    else:
        raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
        if "train" not in raw_datasets:
            raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")

        # Tokenize once on rank 0; other ranks reuse the cached result.
        with accelerator.main_process_first():
            tokenized = raw_datasets["train"].map(
                lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
                batched=True,
                remove_columns=raw_datasets["train"].column_names,
                desc="Tokenizing",
            )

        # mlm=False: the collator only pads and builds labels; masking is done
        # by the noise scheduler below, not by the collator.
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
        train_dataloader = DataLoader(
            tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    # Derive the epoch count needed to reach max_train_steps optimizer updates.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=cfg.max_train_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)

    global_step = 0
    model.train()

    for _epoch in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))

                # Deterministic masking per optimizer step.
                # NOTE(review): global_step only advances on sync steps, so
                # micro-batches within one accumulation window reuse the same
                # seed — confirm this is intended.
                gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
                # Assumed contract of add_noise: returns (forward-order noisy
                # ids, reverse-order noisy ids, and the boolean maps of masked
                # positions for each) — TODO confirm against the scheduler.
                noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
                    input_ids,
                    attention_mask,
                    prompt_length=cfg.prompt_length,
                    block_length=cfg.block_length,
                    mask_token_id=mask_token_id,
                    generator=gen,
                )

                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                )

                # Two passes: forward-order and reverse-order masked inputs.
                logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
                logits_rev = model(
                    input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
                ).logits

                # Forbid predicting the mask token itself by flooring its logit;
                # clone first so the in-place write doesn't touch autograd-saved
                # activations.
                logits = logits.clone()
                logits[..., mask_token_id] = torch.finfo(logits.dtype).min
                logits_rev = logits_rev.clone()
                logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min

                # Supervise only masked positions that are real (unpadded) tokens.
                valid = attention_mask.to(dtype=torch.bool)
                masked = masked & valid
                masked_rev = masked_rev & valid

                labels = input_ids.clone()
                labels[~masked] = -100
                labels_rev = input_ids.clone()
                labels_rev[~masked_rev] = -100

                weights = masked.to(dtype=logits.dtype)
                weights_rev = masked_rev.to(dtype=logits.dtype)

                loss, loss_sft, loss_conf = compute_confidence_aware_loss(
                    logits,
                    labels,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights,
                )
                loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
                    logits_rev,
                    labels_rev,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights_rev,
                )

                total_loss = loss + loss_rev
                # accelerator.accumulate makes step/zero_grad no-ops except on
                # the final micro-batch of each accumulation window.
                accelerator.backward(total_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # sync_gradients is True once per completed accumulation window,
            # i.e. once per real optimizer update.
            if accelerator.sync_gradients:
                global_step += 1

                if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
                    # NOTE(review): the loss values here are the local,
                    # per-micro-batch values, not gathered across processes.
                    logger.info(
                        "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
                        global_step,
                        total_loss.item(),
                        (loss_sft + loss_sft_rev).item(),
                        (loss_conf + loss_conf_rev).item(),
                        lr_scheduler.get_last_lr()[0],
                    )
                    # NOTE(review): duplicates the logger line on stdout —
                    # likely for environments without logging configured.
                    print(
                        f"step={global_step} loss={total_loss.item():.4f} "
                        f"sft={(loss_sft + loss_sft_rev).item():.4f} "
                        f"conf={(loss_conf + loss_conf_rev).item():.4f} "
                        f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
                    )

                if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
                    # Barrier before saving so all ranks have finished the step.
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_dir, exist_ok=True)
                        accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
                        tokenizer.save_pretrained(save_dir)

                if global_step >= cfg.max_train_steps:
                    break

        if global_step >= cfg.max_train_steps:
            break

    # Final save of model + tokenizer on rank 0.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(cfg.output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(final_dir)

    logger.info("Done.")
||||
|
||||
|
||||
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
||||
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLKVAE",
|
||||
"AutoencoderKLKVAEVideo",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
@@ -342,6 +344,8 @@ else:
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"AmusedScheduler",
|
||||
"BlockRefinementScheduler",
|
||||
"BlockRefinementSchedulerOutput",
|
||||
"CMStochasticIterativeScheduler",
|
||||
"CogVideoXDDIMScheduler",
|
||||
"CogVideoXDPMScheduler",
|
||||
@@ -578,6 +582,8 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LLaDA2Pipeline",
|
||||
"LLaDA2PipelineOutput",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
@@ -975,6 +981,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
@@ -1120,6 +1128,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .quantizers import DiffusersQuantizer
|
||||
from .schedulers import (
|
||||
AmusedScheduler,
|
||||
BlockRefinementScheduler,
|
||||
BlockRefinementSchedulerOutput,
|
||||
CMStochasticIterativeScheduler,
|
||||
CogVideoXDDIMScheduler,
|
||||
CogVideoXDPMScheduler,
|
||||
@@ -1335,6 +1345,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LLaDA2Pipeline,
|
||||
LLaDA2PipelineOutput,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@lru_cache_unless_export(maxsize=64)
|
||||
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
|
||||
@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
    """Convert a Kohya-format (`lora_unet_*`) Flux2 LoRA state dict to diffusers/PEFT naming.

    Keys are consumed (popped) from `state_dict` as they are converted; anything
    left over at the end is reported as unsupported.
    """

    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        # One-to-one mapping: lora_down/lora_up -> lora_A/lora_B, folding the
        # alpha/rank scale into the weights.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        # If no alpha is stored, default to alpha == rank (i.e. scale == 1).
        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
        scale = alpha / rank

        # Split the overall scale between A and B as balanced powers of two to
        # keep both factors in a numerically comfortable range.
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        # One-to-many mapping: a fused projection (e.g. qkv) is split into the
        # separate diffusers modules listed in `ait_keys`. `dims` optionally
        # gives the output-row split sizes; by default the rows are split evenly.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        # NOTE: unlike `_convert_to_ai_toolkit`, alpha is kept as a tensor here
        # (no `.item()`); the arithmetic below still works via 0-dim tensor ops.
        default_alpha = torch.tensor(
            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
        )
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
        scale = alpha / sd_lora_rank

        # Same balanced power-of-two split of the scale as above.
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # check if upweight is sparse
        # "Sparse" means block-diagonal: each output split only reads its own
        # rank-slice of the fused LoRA, so A can be chunked instead of shared.
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    # Off-diagonal block (j, k) must be all zeros for sparsity.
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # Dense case: every split shares the full lora_A; lora_B is split
            # along its output rows.
            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
        else:
            # Block-diagonal case: chunk lora_A by rank and keep only each
            # split's diagonal block of lora_B.
            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
            i = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                i += dims[j]

    # Detect number of blocks from keys
    num_double_layers = 0
    num_single_layers = 0
    for key in state_dict.keys():
        if key.startswith("lora_unet_double_blocks_"):
            block_idx = int(key.split("_")[4])
            num_double_layers = max(num_double_layers, block_idx + 1)
        elif key.startswith("lora_unet_single_blocks_"):
            block_idx = int(key.split("_")[4])
            num_single_layers = max(num_single_layers, block_idx + 1)

    ait_sd = {}

    for i in range(num_double_layers):
        # Attention projections
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_proj",
            f"transformer.transformer_blocks.{i}.attn.to_out.0",
        )
        # Fused image-stream qkv -> separate to_q/to_k/to_v modules.
        _convert_to_ai_toolkit_cat(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.to_q",
                f"transformer.transformer_blocks.{i}.attn.to_k",
                f"transformer.transformer_blocks.{i}.attn.to_v",
            ],
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_proj",
            f"transformer.transformer_blocks.{i}.attn.to_add_out",
        )
        # Fused text-stream qkv -> separate add_{q,k,v}_proj modules.
        _convert_to_ai_toolkit_cat(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
            ],
        )
        # MLP layers (Flux2 uses ff.linear_in/linear_out)
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_mlp_0",
            f"transformer.transformer_blocks.{i}.ff.linear_in",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_mlp_2",
            f"transformer.transformer_blocks.{i}.ff.linear_out",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_mlp_0",
            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_mlp_2",
            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
        )

    for i in range(num_single_layers):
        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear1",
            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
        )
        # Single blocks: linear2 -> attn.to_out
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear2",
            f"transformer.single_transformer_blocks.{i}.attn.to_out",
        )

    # Handle optional extra keys
    extra_mappings = {
        "lora_unet_img_in": "transformer.x_embedder",
        "lora_unet_txt_in": "transformer.context_embedder",
        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
        "lora_unet_final_layer_linear": "transformer.proj_out",
    }
    for sds_key, ait_key in extra_mappings.items():
        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)

    # Anything not popped above is a layout we don't know how to convert.
    remaining_keys = list(state_dict.keys())
    if remaining_keys:
        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")

    return ait_sd
||||
|
||||
|
||||
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
"""
|
||||
Convert non-diffusers ZImage LoRA state dict to diffusers format.
|
||||
|
||||
@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
|
||||
_convert_bfl_flux_control_lora_to_diffusers,
|
||||
_convert_fal_kontext_lora_to_diffusers,
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux2_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
@@ -5673,6 +5674,13 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
||||
if is_kohya:
|
||||
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
|
||||
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
|
||||
if is_peft_format:
|
||||
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
@@ -44,33 +45,13 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"UNet2DConditionModel": _maybe_expand_lora_scales,
|
||||
"UNetMotionModel": _maybe_expand_lora_scales,
|
||||
"SD3Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTX2TextConnectors": lambda model_cls, weights: weights,
|
||||
}
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
|
||||
lambda: (lambda model_cls, weights: weights),
|
||||
{
|
||||
"UNet2DConditionModel": _maybe_expand_lora_scales,
|
||||
"UNetMotionModel": _maybe_expand_lora_scales,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PeftAdapterMixin:
|
||||
|
||||
@@ -40,6 +40,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
@@ -161,6 +163,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
|
||||
@@ -49,7 +49,7 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -575,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
batch_size: int,
|
||||
seq_len_q: int,
|
||||
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -9,6 +9,8 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_kvae import AutoencoderKLKVAE
|
||||
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanImageRefinerAttnBlock(nn.Module):
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanVideo15AttnBlock(nn.Module):
|
||||
|
||||
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,802 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class KVAEResnetBlock2D(nn.Module):
    r"""
    A Resnet block with optional guidance.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`):
            If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection.
        temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding.
        zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization.
        add_conv (`bool`, *optional*, default to `False`):
            If `True` add conv2d layer for normalization.
        act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        temb_channels: int = 512,
        zq_ch: Optional[int] = None,
        add_conv: bool = False,
        act_fn: str = "swish",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.nonlinearity = get_activation(act_fn)

        # Spatially-conditioned norm when guidance channels are given, plain GroupNorm otherwise.
        if zq_ch is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        if zq_ch is None:
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=(1, 1),
            padding_mode="replicate",
        )
        # Channel-matching skip connection: a 3x3 conv or a 1x1 "network-in-network" projection.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=(1, 1),
                    padding_mode="replicate",
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )

    def forward(
        self, x: torch.Tensor, temb: Optional[torch.Tensor], zq: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Residual forward pass.

        Args:
            x: Input feature map of shape `(batch, in_channels, height, width)`.
            temb: Optional timestep embedding of shape `(batch, temb_channels)`, or `None`.
            zq: Optional spatial guidance tensor for `KVAEDecoderSpatialNorm2D`; required
                only when the block was built with `zq_ch` set.

        Returns:
            Tensor of shape `(batch, out_channels, height, width)`: shortcut(x) + residual(x).
        """
        h = x

        if zq is None:
            h = self.norm1(h)
        else:
            h = self.norm1(h, zq)

        h = self.nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Linear output is (B, C); append two singleton dims so it broadcasts over the
            # (B, C, H, W) conv activation. The previous 5-D indexing
            # `[:, :, None, None, None]` (copied from a 3-D/video resnet block) produced a
            # (B, C, 1, 1, 1) tensor that cannot broadcast against a 4-D feature map.
            h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        if zq is None:
            h = self.norm2(h)
        else:
            h = self.norm2(h, zq)

        h = self.nonlinearity(h)

        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
||||
|
||||
|
||||
class KVAEPXSDownsample(nn.Module):
    r"""
    Downsample a feature map by `factor` as the sum of two parallel paths — a
    channel-averaged pixel-unshuffle and a strided 3x3 convolution — merged by a
    final 1x1 convolution.

    Args:
        in_channels (`int`): The number of channels in the input.
        factor (`int`, *optional*, default to `2`): The downsampling factor.
    """

    def __init__(self, in_channels: int, factor: int = 2):
        super().__init__()
        self.factor = factor
        self.unshuffle = nn.PixelUnshuffle(factor)
        self.spatial_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
        )
        self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, height, width)
        # Path 1: pixel-unshuffle packs each factor x factor patch into the
        # channel dim; averaging over the packed axis restores `in_channels`.
        shuffled = self.unshuffle(x)
        batch, packed_channels, height, width = shuffled.shape
        grouped = shuffled.view(batch, packed_channels // self.factor**2, self.factor**2, height, width)
        averaged = grouped.mean(dim=2)

        # Path 2: learned strided convolution.
        strided = self.spatial_conv(x)

        # Merge both paths with a 1x1 projection.
        return self.linear(strided + averaged)
|
||||
|
||||
|
||||
class KVAEPXSUpsample(nn.Module):
    r"""
    Upsample a feature map by `factor` as the sum of two parallel paths — a
    channel-repeated pixel-shuffle and a nearest-neighbor interpolation followed
    by a 3x3 convolution — merged by a final 1x1 convolution.

    Args:
        in_channels (`int`): The number of channels in the input.
        factor (`int`, *optional*, default to `2`): The upsampling factor.
    """

    def __init__(self, in_channels: int, factor: int = 2):
        super().__init__()
        self.factor = factor
        self.shuffle = nn.PixelShuffle(factor)
        self.spatial_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
        )
        self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Path 1: duplicate every channel factor**2 times so pixel-shuffle
        # spreads copies of each pixel over the upsampled patch.
        repeated = x.repeat_interleave(self.factor**2, dim=1)
        pxs_interm = self.shuffle(repeated)

        # Path 2: nearest-neighbor upsample + learned 3x3 convolution.
        # Fix: scale by `self.factor` — the original hard-coded `2`, which made
        # the addition below shape-mismatch for any other factor.
        image_like_ups = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        conv_out = self.spatial_conv(image_like_ups)

        # Merge both paths with a 1x1 projection.
        return self.linear(conv_out + pxs_interm)
|
||||
|
||||
|
||||
class KVAEDecoderSpatialNorm2D(nn.Module):
    r"""
    A 2D normalization module for decoder.

    GroupNorm whose per-pixel scale and shift are predicted from a guidance
    tensor `zq` (resized to the feature map's resolution).

    Args:
        in_channels (`int`): The number of channels in the input.
        zq_channels (`int`): The number of channels in the guidance.
        add_conv (`bool`, *optional*, default to `False`):
            If `True` add a conv2d 3x3 layer applied to the guidance first.
    """

    def __init__(
        self,
        in_channels: int,
        zq_channels: int,
        add_conv: bool = False,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)

        self.add_conv = add_conv
        if add_conv:
            self.conv = nn.Conv2d(
                in_channels=zq_channels,
                out_channels=zq_channels,
                kernel_size=3,
                padding=(1, 1),
                padding_mode="replicate",
            )

        # 1x1 projections turning the guidance into per-pixel scale (`y`) and
        # shift (`b`) applied on top of the normalized features.
        self.conv_y = nn.Conv2d(in_channels=zq_channels, out_channels=in_channels, kernel_size=1)
        self.conv_b = nn.Conv2d(in_channels=zq_channels, out_channels=in_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        # Resize the guidance to the spatial size of the feature map.
        spatial_size = f.shape[2:]
        zq = F.interpolate(zq, size=spatial_size, mode="nearest")

        if self.add_conv:
            zq = self.conv(zq)

        normalized = self.norm_layer(f)
        return normalized * self.conv_y(zq) + self.conv_b(zq)
|
||||
|
||||
|
||||
class KVAEEncoder2D(nn.Module):
    r"""
    A 2D encoder module.

    Args:
        ch (`int`): The base number of channels in multiresolution blocks.
        ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
            The channel multipliers in multiresolution blocks.
        num_res_blocks (`int`): The number of Resnet blocks per resolution level.
        in_channels (`int`): The number of channels in the input.
        z_channels (`int`): The number of output channels.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
    """

    def __init__(
        self,
        *,
        ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        in_channels: int,
        z_channels: int,
        double_z: bool = True,
        act_fn: str = "swish",
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the encoder
        self.num_resolutions = len(ch_mult)
        # Accept either one depth for all levels or one depth per level.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = [num_res_blocks] * self.num_resolutions
        else:
            self.num_res_blocks = num_res_blocks
        self.nonlinearity = get_activation(act_fn)
        self.in_channels = in_channels

        self.conv_in = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.ch,
            kernel_size=3,
            padding=(1, 1),
        )

        # Multiresolution downsampling tower.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for level in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * in_ch_mult[level]
            block_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks[level]):
                res_blocks.append(
                    KVAEResnetBlock2D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                    )
                )
                block_in = block_out
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if level < self.num_resolutions - 1:
                stage.downsample = KVAEPXSDownsample(in_channels=block_in)  # mb: bad out channels
            self.down.append(stage)

        # Middle blocks at the lowest resolution.
        self.mid = nn.Module()
        self.mid.block_1 = KVAEResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
        )
        self.mid.block_2 = KVAEResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
        )

        # Output head.
        self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(
            in_channels=block_in,
            out_channels=2 * z_channels if double_z else z_channels,
            kernel_size=3,
            padding=(1, 1),
        )

        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        temb = None  # encoder is unconditional w.r.t. timesteps
        use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing

        # Downsampling tower.
        h = self.conv_in(x)
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for i_block in range(self.num_res_blocks[level]):
                if use_ckpt:
                    h = self._gradient_checkpointing_func(stage.block[i_block], h, temb)
                else:
                    h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        # Middle blocks.
        if use_ckpt:
            h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
            h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
        else:
            h = self.mid.block_1(h, temb)
            h = self.mid.block_2(h, temb)

        # Output head.
        h = self.norm_out(h)
        h = self.nonlinearity(h)
        return self.conv_out(h)
|
||||
|
||||
|
||||
class KVAEDecoder2D(nn.Module):
    r"""
    A 2D decoder module.

    Args:
        ch (`int`): The base number of channels in multiresolution blocks.
        out_ch (`int`): The number of output channels.
        ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
            The channel multipliers in multiresolution blocks.
        num_res_blocks (`int`): The number of Resnet blocks per resolution level.
        in_channels (`int`): The number of channels in the input.
        z_channels (`int`): The number of input (latent) channels.
        give_pre_end (`bool`, *optional*, default to `False`):
            If `True` exit the forward pass early and return the penultimate feature map.
        zq_ch (`int`, *optional*, default to `None`): The number of channels in the guidance;
            falls back to `z_channels` when `None`.
        add_conv (`bool`, *optional*, default to `False`):
            If `True` add a conv2d layer to the Resnet normalization layers.
        act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
    """

    def __init__(
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        in_channels: int,
        z_channels: int,
        give_pre_end: bool = False,
        zq_ch: Optional[int] = None,
        add_conv: bool = False,
        act_fn: str = "swish",
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the decoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.nonlinearity = get_activation(act_fn)

        # The latent itself is used as spatial guidance unless told otherwise.
        if zq_ch is None:
            zq_ch = z_channels

        # Channel width at the lowest resolution, where decoding starts.
        block_in = ch * ch_mult[self.num_resolutions - 1]

        self.conv_in = nn.Conv2d(
            in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
        )

        # Middle blocks at the lowest resolution.
        self.mid = nn.Module()
        self.mid.block_1 = KVAEResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            zq_ch=zq_ch,
            add_conv=add_conv,
        )
        self.mid.block_2 = KVAEResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            zq_ch=zq_ch,
            add_conv=add_conv,
        )

        # Multiresolution upsampling tower; built from the lowest resolution up
        # and stored in ascending level order via insert(0, ...).
        self.up = nn.ModuleList()
        for level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(
                    KVAEResnetBlock2D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        zq_ch=zq_ch,
                        add_conv=add_conv,
                    )
                )
                block_in = block_out
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if level != 0:
                stage.upsample = KVAEPXSUpsample(in_channels=block_in)
            self.up.insert(0, stage)

        self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv)  # , gather=gather_norm)

        self.conv_out = nn.Conv2d(
            in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
        )

        self.gradient_checkpointing = False

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        self.last_z_shape = z.shape

        temb = None  # decoder is unconditional w.r.t. timesteps
        zq = z  # the latent doubles as spatial guidance for the conditioned norms
        use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing

        h = self.conv_in(z)

        # Middle blocks.
        if use_ckpt:
            h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
            h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
        else:
            h = self.mid.block_1(h, temb, zq)
            h = self.mid.block_2(h, temb, zq)

        # Upsampling tower, lowest resolution (highest level index) first.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for i_block in range(self.num_res_blocks + 1):
                if use_ckpt:
                    h = self._gradient_checkpointing_func(stage.block[i_block], h, temb, zq)
                else:
                    h = stage.block[i_block](h, temb, zq)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h, zq)
            if level != 0:
                h = stage.upsample(h)

        # Optionally stop before the output head.
        if self.give_pre_end:
            return h

        h = self.norm_out(h, zq)
        h = self.nonlinearity(h)
        return self.conv_out(h)
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
        num_enc_blocks (int, *optional*, defaults to 2):
            The number of Resnet blocks in encoder multiresolution layers.
        num_dec_blocks (int, *optional*, defaults to 2):
            The number of Resnet blocks in decoder multiresolution layers.
        z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels of encoder.
        ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
            The channel multipliers in multiresolution blocks.
        sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        channels: int = 128,
        num_enc_blocks: int = 2,
        num_dec_blocks: int = 2,
        z_channels: int = 16,
        double_z: bool = True,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        sample_size: int = 1024,
    ):
        super().__init__()

        # Encoder / decoder pair sharing the channel schedule.
        self.encoder = KVAEEncoder2D(
            in_channels=in_channels,
            ch=channels,
            ch_mult=ch_mult,
            num_res_blocks=num_enc_blocks,
            z_channels=z_channels,
            double_z=double_z,
        )
        self.decoder = KVAEDecoder2D(
            out_ch=in_channels,
            ch=channels,
            ch_mult=ch_mult,
            num_res_blocks=num_dec_blocks,
            in_channels=None,
            z_channels=z_channels,
        )

        self.use_slicing = False
        self.use_tiling = False

        # Tiling parameters — only relevant when VAE tiling is enabled.
        self.tile_sample_min_size = self.config.sample_size
        base_sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(base_sample_size / (2 ** (len(self.config.ch_mult) - 1)))
        self.tile_overlap_factor = 0.25

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch (no slicing); dispatches to tiled encoding for large inputs."""
        _, _, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
            return self._tiled_encode(x)

        return self.encoder(x)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time to bound peak memory.
            h = torch.cat([self._encode(piece) for piece in x.split(1)])
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """Decode a batch (no slicing); dispatches to tiled decoding for large latents."""
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        dec = self.decoder(z)
        return DecoderOutput(sample=dec) if return_dict else (dec,)

    @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        if self.use_slicing and z.shape[0] > 1:
            # Decode one latent at a time to bound peak memory.
            decoded = torch.cat([self._decode(piece).sample for piece in z.split(1)])
        else:
            decoded = self._decode(z).sample

        return DecoderOutput(sample=decoded) if return_dict else (decoded,)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the bottom rows of `a` into the top rows of `b` (modifies `b`)."""
        extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(extent):
            weight = y / extent
            b[:, :, y, :] = a[:, :, -extent + y, :] * (1 - weight) + b[:, :, y, :] * weight
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Linearly blend the right columns of `a` into the left columns of `b` (modifies `b`)."""
        extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(extent):
            weight = x / extent
            b[:, :, :, x] = a[:, :, :, -extent + x] * (1 - weight) + b[:, :, :, x] * weight
        return b

    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        The input is split into overlapping tiles that are encoded independently and blended back together, keeping
        memory use roughly constant regardless of image size. The result differs slightly from non-tiled encoding;
        overlapping and blending the tiles makes the seams much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of images.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded images.
        """

        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Encode overlapping tiles independently.
        rows = []
        for top in range(0, x.shape[2], overlap_size):
            row = []
            for left in range(0, x.shape[3], overlap_size):
                tile = x[:, :, top : top + self.tile_sample_min_size, left : left + self.tile_sample_min_size]
                row.append(self.encoder(tile))
            rows.append(row)

        # Blend each tile with its upper and left neighbors, crop, and stitch.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        return torch.cat(result_rows, dim=2)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Decode overlapping latent tiles independently; the overlap avoids
        # seams between tiles.
        rows = []
        for top in range(0, z.shape[2], overlap_size):
            row = []
            for left in range(0, z.shape[3], overlap_size):
                tile = z[:, :, top : top + self.tile_latent_min_size, left : left + self.tile_latent_min_size]
                row.append(self.decoder(tile))
            rows.append(row)

        # Blend each tile with its upper and left neighbors, crop, and stitch.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        return DecoderOutput(sample=dec) if return_dict else (dec,)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                RNG used when sampling from the posterior.
        """
        posterior = self.encode(sample).latent_dist
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        dec = self.decode(z).sample

        return DecoderOutput(sample=dec) if return_dict else (dec,)
|
||||
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
@@ -0,0 +1,954 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
    """Apply the SiLU ("swish") activation used throughout the KVAE video blocks."""
    return F.silu(x)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base layers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAESafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.

    Inputs estimated above ~3 GB are processed chunk-by-chunk along the temporal
    dimension (dim 2), carrying `kernel_size - 1` frames of overlap between
    consecutive chunks so the stitched result matches one full convolution.
    When `write_to` is supplied, the output is written into that preallocated
    tensor in place and the same tensor is returned.
    """

    def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
        # Estimated input size in GB decides whether chunking is needed.
        memory_gb = input.numel() * input.element_size() / (10**9)

        if memory_gb <= 3:
            # Small enough: run a plain Conv3d forward.
            if write_to is None:
                return super().forward(input)
            write_to[...] = super().forward(input)
            return write_to

        time_kernel = self.kernel_size[0]
        num_parts = math.ceil(memory_gb / 2)
        chunks = torch.chunk(input, num_parts, dim=2)

        if write_to is None:
            outputs = []
            window = None
            for idx, chunk in enumerate(chunks):
                if idx == 0 or time_kernel == 1:
                    window = torch.clone(chunk)
                else:
                    # Prepend the trailing frames of the previous window so the
                    # temporal receptive field spans the chunk boundary.
                    window = torch.cat([window[:, :, -time_kernel + 1 :], chunk], dim=2)
                outputs.append(super().forward(window))
            return torch.cat(outputs, dim=2)

        # In-place variant: write each chunk's result at its time offset.
        time_offset = 0
        window = None
        for idx, chunk in enumerate(chunks):
            if idx == 0 or time_kernel == 1:
                window = torch.clone(chunk)
            else:
                window = torch.cat([window[:, :, -time_kernel + 1 :], chunk], dim=2)
            out_frames = window.size(2) - (time_kernel - 1)
            write_to[:, :, time_offset : time_offset + out_frames] = super().forward(window)
            time_offset += out_frames
        return write_to
|
||||
|
||||
|
||||
class KVAECausalConv3d(nn.Module):
    r"""
    A 3D causal convolution layer.

    Pads symmetrically in space and only on the past side in time, so no output
    frame depends on future frames.
    """

    def __init__(
        self,
        chan_in: int,
        chan_out: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Tuple[int, int, int] = (1, 1, 1),
        dilation: Tuple[int, int, int] = (1, 1, 1),
        **kwargs,
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # Spatial padding keeps resolution; temporal padding is fully causal.
        self.height_pad = height_kernel_size // 2
        self.width_pad = width_kernel_size // 2
        self.time_pad = time_kernel_size - 1
        self.time_kernel_size = time_kernel_size
        self.stride = stride

        self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # F.pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back);
        # the temporal pad is applied only on the past (front) side.
        pad = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
        return self.conv(F.pad(input, pad, mode="replicate"))
|
||||
|
||||
|
||||
class KVAECachedCausalConv3d(nn.Module):
    r"""
    A 3D causal convolution layer with caching for temporal processing.

    Instead of re-padding in time, the trailing frames of each call are stored
    in `cache["padding"]` and prepended on the next call, so long videos can be
    processed in temporal segments with results identical to one full pass.
    """

    def __init__(
        self,
        chan_in: int,
        chan_out: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Tuple[int, int, int] = (1, 1, 1),
        dilation: Tuple[int, int, int] = (1, 1, 1),
        **kwargs,
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.height_pad = height_kernel_size // 2
        self.width_pad = width_kernel_size // 2
        self.time_pad = time_kernel_size - 1
        self.time_kernel_size = time_kernel_size
        self.stride = stride

        self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
        t_stride = self.stride[0]
        # NOTE(review): this pad tuple applies `height_pad` to the width axis
        # and `width_pad` to the height axis (F.pad order is W, H, T) — the
        # opposite of KVAECausalConv3d. Only observable for non-square spatial
        # kernels; confirm intent upstream before changing.
        spatial_pad = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
        padded = F.pad(input, spatial_pad, mode="replicate")

        if cache["padding"] is None:
            # First call: synthesize temporal history by repeating frame 0.
            first = padded[:, :, :1]
            history_shape = list(first.shape)
            history_shape[2] = self.time_pad
            history = first.expand(history_shape)
        else:
            history = cache["padding"]

        # Pre-allocate the output tensor.
        out_shape = list(input.shape)
        out_shape[1] = self.conv.out_channels
        if t_stride == 2:
            out_shape[2] = (input.size(2) + 1) // 2
        output = torch.empty(tuple(out_shape), dtype=input.dtype, device=input.device)

        # Output frames whose receptive field reaches into the cached history.
        offset_out = math.ceil(history.size(2) / t_stride)
        offset_in = offset_out * t_stride - history.size(2)

        if offset_out > 0:
            head = torch.cat([history, padded[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2)
            output[:, :, :offset_out] = self.conv(head)

        if offset_out < output.size(2):
            output[:, :, offset_out:] = self.conv(padded[:, :, offset_in:])

        # Stash the trailing frames needed by the next call.
        pad_offset = (
            offset_in
            + t_stride * math.trunc((padded.size(2) - offset_in - self.time_kernel_size) / t_stride)
            + t_stride
        )
        cache["padding"] = torch.clone(padded[:, :, pad_offset:])

        return output
|
||||
|
||||
|
||||
class KVAECachedGroupNorm(nn.Module):
    r"""
    GroupNorm with caching support for temporal processing.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

    def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
        normalized = self.norm_layer(x)
        # Mark the cache as initialized on first use; the values are sentinel
        # placeholders (no running statistics are tracked here).
        if cache is not None and cache.get("mean") is None and cache.get("var") is None:
            cache["mean"] = 1
            cache["var"] = 1
        return normalized
|
||||
|
||||
|
||||
# =============================================================================
# Cached layers
# =============================================================================
class KVAECachedSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization for decoder with caching.

    Normalizes the feature map `f` with a cached GroupNorm, then modulates it with a multiplicative
    and an additive term predicted from the latent `zq` (nearest-interpolated to the size of `f`).

    Args:
        f_channels (`int`): Number of channels of the feature map `f`.
        zq_channels (`int`): Number of channels of the conditioning latent `zq`.
        add_conv (`bool`, *optional*, defaults to `False`):
            If `True`, apply an extra cached causal convolution to `zq` before modulation.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        add_conv: bool = False,
    ):
        super().__init__()
        self.norm_layer = KVAECachedGroupNorm(f_channels)
        self.add_conv = add_conv

        if add_conv:
            self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)

        # 1x1x1 projections that turn zq into the scale (conv_y) and shift (conv_b) terms.
        self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
        self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)

    def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
        """Modulate normalized `f` with projections of `zq`; `cache` carries streaming state."""
        # The norm cache being unset marks the FIRST segment of a streaming pass: the first frame
        # is interpolated separately from the rest so it never mixes with later time steps.
        if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]

            zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")

            if zq.size(2) > 1:
                # Interpolate in channel chunks of 32 to bound F.interpolate's peak memory.
                zq_rest_splits = torch.split(zq_rest, 32, dim=1)
                interpolated_splits = [
                    F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
                ]
                zq_rest = torch.cat(interpolated_splits, dim=1)
                zq = torch.cat([zq_first, zq_rest], dim=2)
            else:
                zq = zq_first
        else:
            # Continuation segments: every frame is a regular frame; interpolate uniformly.
            f_size = f.shape[-3:]
            zq_splits = torch.split(zq, 32, dim=1)
            interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
            zq = torch.cat(interpolated_splits, dim=1)

        if self.add_conv:
            zq = self.conv(zq, cache["add_conv"])

        # Note: the norm call also primes cache["norm"], so the branch above only runs once.
        norm_f = self.norm_layer(f, cache["norm"])
        norm_f = norm_f * self.conv_y(zq)
        norm_f = norm_f + self.conv_b(zq)

        return norm_f
class KVAECachedResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block with caching.

    Two norm/SiLU/causal-conv stages with a residual connection. Encoder blocks (``zq_ch is None``)
    use cached GroupNorm; decoder blocks use the spatially conditioned norm driven by the latent
    ``zq``. All temporal convolutions read and update entries of the per-layer ``layer_cache``.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`, *optional*): Number of output channels; defaults to `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Use a 3x3x3 cached causal conv (instead of a 1x1x1 conv) for the channel-changing
            shortcut path.
        dropout (`float`, *optional*, defaults to 0.0):
            Accepted for config compatibility; not used in the visible implementation.
        temb_channels (`int`, *optional*, defaults to 0): Size of the optional time embedding.
        zq_ch (`int`, *optional*): Channels of the conditioning latent (decoder path) or `None`.
        add_conv (`bool`, *optional*, defaults to `False`): Forwarded to the spatial norm.
        gather_norm (`bool`, *optional*, defaults to `False`):
            Accepted for config compatibility; not used in the visible implementation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 0,
        zq_ch: Optional[int] = None,
        add_conv: bool = False,
        gather_norm: bool = False,
    ):
        super().__init__()
        out_channels = out_channels if out_channels is not None else in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        def build_norm(channels: int) -> nn.Module:
            # Plain cached GroupNorm for the encoder, zq-conditioned spatial norm for the decoder.
            if zq_ch is None:
                return KVAECachedGroupNorm(channels)
            return KVAECachedSpatialNorm3D(channels, zq_ch, add_conv=add_conv)

        self.norm1 = build_norm(in_channels)
        self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)

        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = build_norm(out_channels)
        self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)

        # Shortcut projection only needed when the channel count changes.
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
            else:
                self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
        """Apply the block; `zq is None` selects the encoder norm call signature."""
        residual = x

        if zq is None:
            hidden = self.norm1(x, cache=layer_cache["norm1"])
        else:
            hidden = self.norm1(x, zq, cache=layer_cache["norm1"])
        hidden = self.conv1(F.silu(hidden), cache=layer_cache["conv1"])

        if temb is not None:
            # Broadcast the projected time embedding over the T/H/W axes.
            hidden = hidden + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

        if zq is None:
            hidden = self.norm2(hidden, cache=layer_cache["norm2"])
        else:
            hidden = self.norm2(hidden, zq, cache=layer_cache["norm2"])
        hidden = self.conv2(F.silu(hidden), cache=layer_cache["conv2"])

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual, cache=layer_cache["conv_shortcut"])
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden
class KVAECachedPXSDownsample(nn.Module):
    r"""
    A 3D downsampling layer using PixelUnshuffle with caching.

    Spatial downsampling (2x in H and W) sums a strided-conv branch with a PixelUnshuffle +
    patch-mean branch. Optional temporal downsampling (2x in T) sums a strided cached causal conv
    with causal pairwise average pooling over time.

    Args:
        in_channels (`int`): Number of input (and output) channels.
        compress_time (`bool`): Whether to also downsample the temporal dimension.
        factor (`int`, *optional*, defaults to 2): Spatial downsampling factor.
    """

    def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
        super().__init__()
        self.temporal_compress = compress_time
        self.factor = factor
        self.unshuffle = nn.PixelUnshuffle(self.factor)
        # NOTE(review): `s_pool` is registered but never used in the visible code paths.
        self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))

        self.spatial_conv = KVAESafeConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            padding_mode="reflect",
        )

        if self.temporal_compress:
            self.temporal_conv = KVAECachedCausalConv3d(
                in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
            )

        # Final 1x1x1 channel mixing.
        self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
        """Halve H and W: strided conv branch + PixelUnshuffle patch-mean branch."""
        b, c, t, h, w = input.shape
        # Fold time into the batch so PixelUnshuffle (a 2D op) can run per frame.
        pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        # pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
        pxs_interm = self.unshuffle(pxs_input)
        b_it, c_it, h_it, w_it = pxs_interm.shape
        # Mean over the factor^2 folded patch positions == average pooling over each patch.
        pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
        pxs_out = torch.mean(pxs_interm_view, dim=2)
        pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
        # pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
        conv_out = self.spatial_conv(input)
        return conv_out + pxs_out

    def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
        """Halve T causally: cached strided conv branch + pairwise avg-pool branch."""
        b, c, t, h, w = input.shape

        # Collapse spatial positions into the batch so avg_pool1d pools over time only.
        permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)

        if cache[0]["padding"] is None:
            # First segment: keep the first frame untouched (causal), pool the rest pairwise.
            first, rest = permuted[..., :1], permuted[..., 1:]
            if rest.size(-1) > 0:
                rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
                full_interp = torch.cat([first, rest_interp], dim=-1)
            else:
                full_interp = first
        else:
            # Continuation segment: pool all frames pairwise.
            # NOTE(review): if a continuation segment ever had t == 0 frames, `full_interp` would
            # be unbound here; callers appear to always pass >= 1 frame — TODO confirm.
            rest = permuted
            if rest.size(-1) > 0:
                full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)

        t_new = full_interp.size(-1)
        full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
        conv_out = self.temporal_conv(input, cache[0])
        return conv_out + full_interp

    def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
        """Spatially (and optionally temporally) downsample `x`; `cache` holds conv state."""
        out = self.spatial_downsample(x)

        if self.temporal_compress:
            out = self.temporal_downsample(out, cache=cache)

        return self.linear(out)
class KVAECachedPXSUpsample(nn.Module):
    r"""
    A 3D upsampling layer using PixelShuffle with caching.

    Spatial upsampling (2x in H and W) is nearest-neighbor interpolation plus a residual conv.
    Optional temporal upsampling (2x in T) repeats frames and adds a cached causal conv residual.

    Args:
        in_channels (`int`): Number of input (and output) channels.
        compress_time (`bool`): Whether to also upsample the temporal dimension.
        factor (`int`, *optional*, defaults to 2): Spatial upsampling factor.
    """

    def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
        super().__init__()
        self.temporal_compress = compress_time
        self.factor = factor
        # NOTE(review): `shuffle` is registered but never used in the visible code paths.
        self.shuffle = nn.PixelShuffle(self.factor)

        self.spatial_conv = KVAESafeConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=(1, 1, 1),
            padding=(0, 1, 1),
            padding_mode="reflect",
        )

        if self.temporal_compress:
            self.temporal_conv = KVAECachedCausalConv3d(
                in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
            )

        # Final 1x1x1 channel mixing.
        self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)

    def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
        """Double H and W with nearest-neighbor interpolation plus a residual 1x3x3 conv."""
        b, c, t, h, w = input.shape
        # Fold time into channels so F.interpolate (a 2D op here) runs per frame.
        input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
        input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
        input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)

        out = self.spatial_conv(input_interp)
        return input_interp + out

    def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
        """Double T (when more than one frame) by frame repetition plus a causal conv residual."""
        # Single-frame segments are not repeated (factor 1); multi-frame segments double.
        time_factor = 1.0 + 1.0 * (input.size(2) > 1)
        if isinstance(time_factor, torch.Tensor):
            time_factor = time_factor.item()

        repeated = input.repeat_interleave(int(time_factor), dim=2)

        # On the first segment, drop the duplicated copy of the causal first frame.
        if cache["padding"] is None:
            tail = repeated[..., int(time_factor - 1) :, :, :]
        else:
            tail = repeated

        conv_out = self.temporal_conv(tail, cache)
        return conv_out + tail

    def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
        """Temporally (optional) then spatially upsample `x`; `cache` holds conv state."""
        if self.temporal_compress:
            x = self.temporal_upsample(x, cache)

        s_out = self.spatial_upsample(x)
        # Preallocated destination passed to the safe conv; presumably lets it write the result
        # into `to` to bound peak memory — TODO confirm against KVAESafeConv3d.
        to = torch.empty_like(s_out)
        lin_out = self.linear(s_out, write_to=to)
        return lin_out
# =============================================================================
# Cached Encoder/Decoder
# =============================================================================
class KVAECachedEncoder3D(nn.Module):
    r"""
    Cached 3D Encoder for KVAE.

    Standard VAE-style encoder: input conv, a pyramid of residual blocks with downsampling, a
    middle stage, and an output conv. Every temporal layer reads/updates entries of the caller's
    `cache_dict` so videos can be encoded segment by segment.

    Args:
        ch (`int`, *optional*, defaults to 128): Base channel count.
        ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`): Per-level multipliers.
        num_res_blocks (`int`, *optional*, defaults to 2): Residual blocks per level.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to the residual blocks.
        in_channels (`int`, *optional*, defaults to 3): Number of input channels.
        z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
        double_z (`bool`, *optional*, defaults to `True`): Output 2x latent channels (mean+logvar).
        temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
    """

    def __init__(
        self,
        ch: int = 128,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int = 2,
        dropout: float = 0.0,
        in_channels: int = 3,
        z_channels: int = 16,
        double_z: bool = True,
        temporal_compress_times: int = 4,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        # Only the first log2(temporal_compress_times) levels also downsample time.
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        block_in = ch

        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # `attn` stays empty in this configuration; kept for structural parity.
            attn = nn.ModuleList()

            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]

            for i_block in range(self.num_res_blocks):
                block.append(
                    KVAECachedResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        temb_channels=self.temb_ch,
                    )
                )
                block_in = block_out

            down = nn.Module()
            down.block = block
            down.attn = attn

            # Every level except the last gets a downsampler; early levels also compress time.
            if i_level != self.num_resolutions - 1:
                if i_level < self.temporal_compress_level:
                    down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
                else:
                    down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = KVAECachedResnetBlock3D(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.block_2 = KVAECachedResnetBlock3D(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        self.norm_out = KVAECachedGroupNorm(block_in)
        self.conv_out = KVAECachedCausalConv3d(
            chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
        )

        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
        """
        Encode one temporal segment.

        Args:
            x (`torch.Tensor`): Video segment of shape `(B, C, T, H, W)`.
            cache_dict (`Dict`): Streaming cache laid out as built by
                `AutoencoderKLKVAEVideo._make_encoder_cache` (integer keys per level, string keys
                for the non-pyramid layers); mutated in place.

        Returns:
            `torch.Tensor`: Latent segment (2x `z_channels` channels when `double_z`).
        """
        # No time embedding is used in this encoder (temb_ch == 0).
        temb = None

        h = self.conv_in(x, cache=cache_dict["conv_in"])

        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    h = self._gradient_checkpointing_func(
                        self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
                    )
                else:
                    h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
            h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
        else:
            h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
            h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])

        h = self.norm_out(h, cache=cache_dict["norm_out"])
        h = nonlinearity(h)
        h = self.conv_out(h, cache=cache_dict["conv_out"])

        return h
class KVAECachedDecoder3D(nn.Module):
    r"""
    Cached 3D Decoder for KVAE.

    Mirror of the encoder: input conv, middle stage, a pyramid of residual blocks with upsampling,
    and an output conv. All residual blocks and the output norm are conditioned on the input latent
    `zq` via spatial normalization. Layers read/update the caller's `cache_dict` so latents can be
    decoded segment by segment.

    Args:
        ch (`int`, *optional*, defaults to 128): Base channel count.
        out_ch (`int`, *optional*, defaults to 3): Number of output channels.
        ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`): Per-level multipliers.
        num_res_blocks (`int`, *optional*, defaults to 2): Residual blocks per level (decoder adds 1).
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to the residual blocks.
        z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
        zq_ch (`int`, *optional*): Conditioning-latent channels; defaults to `z_channels`.
        add_conv (`bool`, *optional*, defaults to `False`): Forwarded to the spatial norms.
        temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int = 2,
        dropout: float = 0.0,
        z_channels: int = 16,
        zq_ch: Optional[int] = None,
        add_conv: bool = False,
        temporal_compress_times: int = 4,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        if zq_ch is None:
            zq_ch = z_channels

        # Decoder starts at the deepest (highest-multiplier) level.
        block_in = ch * ch_mult[self.num_resolutions - 1]

        self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)

        self.mid = nn.Module()
        self.mid.block_1 = KVAECachedResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
        )
        self.mid.block_2 = KVAECachedResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
        )

        self.up = nn.ModuleList()
        # Build from deepest to shallowest, inserting at 0 so self.up[i] matches level i.
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            # `attn` stays empty in this configuration; kept for structural parity.
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]

            # Decoder levels use one extra residual block (num_res_blocks + 1).
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    KVAECachedResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        zq_ch=zq_ch,
                        add_conv=add_conv,
                    )
                )
                block_in = block_out

            up = nn.Module()
            up.block = block
            up.attn = attn

            # Every level except level 0 upsamples; the deepest levels also expand time.
            if i_level != 0:
                if i_level < self.num_resolutions - self.temporal_compress_level:
                    up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
                else:
                    up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
            self.up.insert(0, up)

        self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
        """
        Decode one latent segment.

        Args:
            z (`torch.Tensor`): Latent segment of shape `(B, z_channels, T, H, W)`.
            cache_dict (`Dict`): Streaming cache laid out as built by
                `AutoencoderKLKVAEVideo._make_decoder_cache`; mutated in place.

        Returns:
            `torch.Tensor`: Decoded video segment with `out_ch` channels.
        """
        # No time embedding is used (temb_ch == 0); the input latent doubles as the spatial-norm
        # conditioning signal for every block.
        temb = None
        zq = z

        h = self.conv_in(z, cache_dict["conv_in"])

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
            h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
        else:
            h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
            h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)

        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    h = self._gradient_checkpointing_func(
                        self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
                    )
                else:
                    h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
            if i_level != 0:
                h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])

        h = self.norm_out(h, zq, cache_dict["norm_out"])
        h = nonlinearity(h)
        h = self.conv_out(h, cache_dict["conv_out"])

        return h
# =============================================================================
# Main AutoencoderKL class
# =============================================================================
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
    [KVAE](https://github.com/kandinskylab/kvae-1).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Parameters:
        ch (`int`, *optional*, defaults to 128): Base channel count.
        ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
        num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
        in_channels (`int`, *optional*, defaults to 3): Number of input channels.
        out_ch (`int`, *optional*, defaults to 3): Number of output channels.
        z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
        temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["KVAECachedResnetBlock3D"]

    @register_to_config
    def __init__(
        self,
        ch: int = 128,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int = 2,
        in_channels: int = 3,
        out_ch: int = 3,
        z_channels: int = 16,
        temporal_compress_times: int = 4,
    ):
        super().__init__()

        self.encoder = KVAECachedEncoder3D(
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            in_channels=in_channels,
            z_channels=z_channels,
            double_z=True,
            temporal_compress_times=temporal_compress_times,
        )

        self.decoder = KVAECachedDecoder3D(
            ch=ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            out_ch=out_ch,
            z_channels=z_channels,
            temporal_compress_times=temporal_compress_times,
        )

        # Batch-slicing toggle; tiling is declared but not implemented in the visible code.
        self.use_slicing = False
        self.use_tiling = False

    def _make_encoder_cache(self) -> Dict:
        """Create empty cache for cached encoder."""

        # `name` encodes the layer kind: "conv" -> causal-conv padding cache; "norm_enc" ->
        # GroupNorm cache; "norm_dec" -> spatial-norm cache (norm + optional conv);
        # "resblock_<module>" -> per-sublayer caches; "<level>_<module>" -> one pyramid level
        # with `p` residual blocks plus down/up sampler conv caches.
        # NOTE(review): mirrors `_make_decoder_cache`'s builder; keep the two key layouts in sync
        # with the encoder/decoder forward lookups.
        def make_dict(name, p=None):
            if name == "conv":
                return {"padding": None}

            layer, module = name.split("_")
            if layer == "norm":
                if module == "enc":
                    return {"mean": None, "var": None}
                else:
                    return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
            elif layer == "resblock":
                return {
                    "norm1": make_dict(f"norm_{module}"),
                    "norm2": make_dict(f"norm_{module}"),
                    "conv1": make_dict("conv"),
                    "conv2": make_dict("conv"),
                    "conv_shortcut": make_dict("conv"),
                }
            elif layer.isdigit():
                out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
                for i in range(p):
                    out_dict[i] = make_dict(f"resblock_{module}")
                return out_dict

        cache = {
            "conv_in": make_dict("conv"),
            "mid_1": make_dict("resblock_enc"),
            "mid_2": make_dict("resblock_enc"),
            "norm_out": make_dict("norm_enc"),
            "conv_out": make_dict("conv"),
        }
        # Encoder uses num_res_blocks per level
        for i in range(len(self.config.ch_mult)):
            cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
        return cache

    def _make_decoder_cache(self) -> Dict:
        """Create empty cache for decoder."""

        # Same builder as `_make_encoder_cache`, but levels use "dec" (spatial norms) and
        # num_res_blocks + 1 residual blocks, matching the decoder architecture.
        def make_dict(name, p=None):
            if name == "conv":
                return {"padding": None}

            layer, module = name.split("_")
            if layer == "norm":
                if module == "enc":
                    return {"mean": None, "var": None}
                else:
                    return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
            elif layer == "resblock":
                return {
                    "norm1": make_dict(f"norm_{module}"),
                    "norm2": make_dict(f"norm_{module}"),
                    "conv1": make_dict("conv"),
                    "conv2": make_dict("conv"),
                    "conv_shortcut": make_dict("conv"),
                }
            elif layer.isdigit():
                out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
                for i in range(p):
                    out_dict[i] = make_dict(f"resblock_{module}")
                return out_dict

        cache = {
            "conv_in": make_dict("conv"),
            "mid_1": make_dict("resblock_dec"),
            "mid_2": make_dict("resblock_dec"),
            "norm_out": make_dict("norm_dec"),
            "conv_out": make_dict("conv"),
        }
        for i in range(len(self.config.ch_mult)):
            cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
        return cache

    def enable_slicing(self) -> None:
        r"""Enable sliced VAE decoding."""
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""Disable sliced VAE decoding."""
        self.use_slicing = False

    def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
        # Cached encoder processes by segments
        cache = self._make_encoder_cache()

        # First segment takes one extra frame (the causal first frame); remaining frames are
        # chunked into seg_len segments, with the leftover folded into the last chunk
        # (n_frames is <= 0 after the loop, so the += shrinks it back to the true remainder).
        split_list = [seg_len + 1]
        n_frames = x.size(2) - (seg_len + 1)
        while n_frames > 0:
            split_list.append(seg_len)
            n_frames -= seg_len
        split_list[-1] += n_frames

        latent = []
        for chunk in torch.split(x, split_list, dim=2):
            l = self.encoder(chunk, cache)
            # Keep only the mean half of the (mean, logvar) output.
            sample, _ = torch.chunk(l, 2, dim=1)
            latent.append(sample)

        return torch.cat(latent, dim=2)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of videos into latents.

        Args:
            x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded videos.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        # For cached encoder, we already did the split in _encode
        # _encode kept only the mean, so rebuild (mean, logvar) with logvar fixed to zeros —
        # the resulting posterior has unit std and its mode() is exactly `h`.
        h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
        posterior = DiagonalGaussianDistribution(h_double)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
        cache = self._make_decoder_cache()
        temporal_compress = self.config.temporal_compress_times

        # Compute the pixel-frame segmentation (first segment seg_len + 1 frames, as in _encode),
        # then convert to latent-frame chunk sizes via ceil division by the compression factor.
        split_list = [seg_len + 1]
        n_frames = temporal_compress * (z.size(2) - 1) - seg_len
        while n_frames > 0:
            split_list.append(seg_len)
            n_frames -= seg_len
        split_list[-1] += n_frames
        split_list = [math.ceil(size / temporal_compress) for size in split_list]

        recs = []
        for chunk in torch.split(z, split_list, dim=2):
            out = self.decoder(chunk, cache)
            recs.append(out)

        return torch.cat(recs, dim=2)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode a batch of videos.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z)

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Encode `sample`, take the posterior mode (or a sample when `sample_posterior`), and decode.

        Args:
            sample (`torch.Tensor`): Input video batch of shape (B, C, T, H, W).
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior instead of taking its mode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*): RNG used when sampling the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class QwenImageUpsample(nn.Upsample):
|
||||
|
||||
@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class WanUpsample(nn.Upsample):
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from math import prod
|
||||
from typing import Any
|
||||
@@ -25,7 +24,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -934,6 +933,7 @@ class QwenImageTransformer2DModel(
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
joint_attention_mask = joint_attention_mask[:, None, None, :]
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@@ -788,9 +788,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
if all(seq == max_seqlen for seq in item_seqlens):
|
||||
attn_mask = None
|
||||
else:
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(item_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
@@ -871,9 +874,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
|
||||
|
||||
# Attention mask
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
if all(seq == max_seqlen for seq in unified_seqlens):
|
||||
attn_mask = None
|
||||
else:
|
||||
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_seqlens):
|
||||
attn_mask[i, :seq_len] = 1
|
||||
|
||||
# Noise mask
|
||||
noise_mask_tensor = None
|
||||
|
||||
@@ -285,6 +285,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
@@ -728,6 +729,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
|
||||
@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
if hasattr(self.language_model, "_get_initial_cache_position"):
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
|
||||
@@ -16,22 +16,29 @@ from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils import (
|
||||
is_cosmos_guardrail_available,
|
||||
is_torch_xla_available,
|
||||
is_torchvision_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
import torchvision.transforms.functional
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
47
src/diffusers/pipelines/llada2/__init__.py
Normal file
47
src/diffusers/pipelines/llada2/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
491
src/diffusers/pipelines/llada2/pipeline_llada2.py
Normal file
491
src/diffusers/pipelines/llada2/pipeline_llada2.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...schedulers import BlockRefinementScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
>>> model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(
|
||||
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
... )
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
>>> scheduler = BlockRefinementScheduler()
|
||||
|
||||
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
|
||||
>>> print(output.texts[0])
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDA2PipelineOutput(BaseOutput):
|
||||
sequences: torch.LongTensor
|
||||
texts: list[str] | None = None
|
||||
|
||||
|
||||
class LLaDA2Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
|
||||
|
||||
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
|
||||
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
|
||||
|
||||
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
|
||||
vocab_size]`.
|
||||
"""
|
||||
|
||||
model: Any
|
||||
scheduler: BlockRefinementScheduler
|
||||
tokenizer: Any
|
||||
|
||||
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
scheduler: BlockRefinementScheduler,
|
||||
tokenizer: Any | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
|
||||
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
# --- Prompt encoding ---
|
||||
|
||||
def _prepare_input_ids(
|
||||
self,
|
||||
*,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
use_chat_template: bool,
|
||||
add_generation_prompt: bool,
|
||||
chat_template_kwargs: dict[str, Any] | None,
|
||||
) -> torch.LongTensor:
|
||||
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
if input_ids.ndim != 2:
|
||||
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
return input_ids
|
||||
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
if messages is not None and prompt is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if messages is None and prompt is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
|
||||
chat_template_kwargs = chat_template_kwargs or {}
|
||||
|
||||
if messages is not None:
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
|
||||
if isinstance(prompt, list):
|
||||
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
|
||||
return encoded["input_ids"]
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
gen_length: int,
|
||||
block_length: int,
|
||||
num_inference_steps: int,
|
||||
minimal_topk: int,
|
||||
threshold: float,
|
||||
sampling_method: str,
|
||||
output_type: str,
|
||||
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None,
|
||||
):
|
||||
# Input source validation
|
||||
if prompt is None and messages is None and input_ids is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
if prompt is not None and messages is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim not in (1, 2):
|
||||
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
if prompt is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
if messages is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
# Generation parameter validation
|
||||
if gen_length <= 0:
|
||||
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
|
||||
if block_length <= 0:
|
||||
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
if minimal_topk <= 0:
|
||||
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
|
||||
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
|
||||
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
|
||||
if sampling_method not in {"auto", "greedy", "multinomial"}:
|
||||
raise ValueError(
|
||||
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
|
||||
)
|
||||
if output_type not in {"seq", "text"}:
|
||||
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
|
||||
|
||||
# Callback validation
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
||||
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] | None = None,
|
||||
messages: list[dict[str, str]] | None = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
use_chat_template: bool = True,
|
||||
add_generation_prompt: bool = True,
|
||||
gen_length: int = 2048,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
sampling_method: str = "multinomial",
|
||||
threshold: float = 0.7,
|
||||
editing_threshold: float | None = 0.5,
|
||||
max_post_steps: int = 16,
|
||||
minimal_topk: int = 1,
|
||||
eos_early_stop: bool = True,
|
||||
eos_token_id: int | None = None,
|
||||
mask_token_id: int | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
output_type: str = "text",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Callable[[int, int, dict], None]
|
||||
| PipelineCallback
|
||||
| MultiPipelineCallbacks
|
||||
| None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None = None,
|
||||
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
|
||||
"""
|
||||
Generate text with block-wise refinement.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
|
||||
available, the prompt is wrapped in a chat message before tokenization.
|
||||
messages (`List[Dict[str, str]]`, *optional*):
|
||||
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
|
||||
when provided. Requires a tokenizer with `apply_chat_template`.
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
|
||||
use_chat_template (`bool`, defaults to `True`):
|
||||
Whether to wrap the prompt in a chat template.
|
||||
add_generation_prompt (`bool`, defaults to `True`):
|
||||
Whether to add the generation prompt when using chat templates.
|
||||
gen_length (`int`):
|
||||
Number of tokens to generate.
|
||||
block_length (`int`):
|
||||
Block size for refinement.
|
||||
num_inference_steps (`int`):
|
||||
Number of refinement steps per block.
|
||||
temperature (`float`):
|
||||
Sampling temperature.
|
||||
top_p (`float`, *optional*):
|
||||
Nucleus sampling cutoff.
|
||||
top_k (`int`, *optional*):
|
||||
Top-k sampling cutoff.
|
||||
sampling_method (`str`):
|
||||
Sampling method (`auto`, `greedy`, `multinomial`).
|
||||
threshold (`float`):
|
||||
Confidence threshold for committing tokens.
|
||||
editing_threshold (`float`, *optional*):
|
||||
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
|
||||
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
|
||||
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
|
||||
negative value to disable editing. Defaults to `0.5`.
|
||||
max_post_steps (`int`):
|
||||
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
|
||||
used when `editing_threshold` is enabled. Defaults to `16`.
|
||||
minimal_topk (`int`):
|
||||
Minimum number of tokens to commit per step.
|
||||
eos_early_stop (`bool`):
|
||||
Whether to stop after committing EOS in a block.
|
||||
eos_token_id (`int`, *optional*):
|
||||
EOS token ID to use for early stopping.
|
||||
mask_token_id (`int`, *optional*):
|
||||
Mask token ID to use for the template.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for sampling.
|
||||
output_type (`str`, defaults to `"text"`):
|
||||
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
|
||||
token ID sequences only.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
|
||||
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
|
||||
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
|
||||
timestep: int, callback_kwargs: Dict)`.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
|
||||
`confidence`, `active_block`.
|
||||
|
||||
Examples:
|
||||
"""
|
||||
# 1. Check inputs early
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["block_x"]
|
||||
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
gen_length=gen_length,
|
||||
block_length=block_length,
|
||||
num_inference_steps=num_inference_steps,
|
||||
minimal_topk=minimal_topk,
|
||||
threshold=threshold,
|
||||
sampling_method=sampling_method,
|
||||
output_type=output_type,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Prepare input IDs from prompt/messages/input_ids
|
||||
prompt_ids = self._prepare_input_ids(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
use_chat_template=use_chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if prompt_ids.ndim == 1:
|
||||
prompt_ids = prompt_ids.unsqueeze(0)
|
||||
prompt_ids = prompt_ids.to(device=device)
|
||||
batch_size, prompt_length = prompt_ids.shape
|
||||
|
||||
if eos_token_id is None:
|
||||
eos_token_id = self.eos_token_id
|
||||
if mask_token_id is None:
|
||||
mask_token_id = self.mask_token_id
|
||||
if mask_token_id is None:
|
||||
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
|
||||
|
||||
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# 3. Build attention mask and position IDs
|
||||
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
|
||||
total_length = num_blocks * block_length
|
||||
|
||||
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
|
||||
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
|
||||
|
||||
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# 4. Prepare latents (fully masked sequence)
|
||||
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
|
||||
if prompt_length > 0:
|
||||
x[:, :prompt_length] = prompt_ids
|
||||
|
||||
prefill_blocks = prompt_length // block_length
|
||||
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
|
||||
|
||||
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
|
||||
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
|
||||
global_step = 0
|
||||
|
||||
# 5. Block-wise refinement loop
|
||||
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
|
||||
block_progress_bar_config["position"] = 0
|
||||
block_progress_bar_config["desc"] = "Blocks"
|
||||
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
|
||||
current_window_end = (num_block + 1) * block_length
|
||||
block_x = x[:, :current_window_end]
|
||||
block_attn_mask = attn_mask[:, :current_window_end]
|
||||
block_position_ids = position_ids[:, :current_window_end]
|
||||
|
||||
# Identify which positions in the block are prompt (non-editable).
|
||||
block_start_pos = num_block * block_length
|
||||
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
|
||||
if block_start_pos < prompt_length:
|
||||
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
|
||||
prompt_mask_in_block[:prompt_end_in_block] = True
|
||||
|
||||
post_steps = 0
|
||||
step_idx = 0
|
||||
should_continue = True
|
||||
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
|
||||
progress_bar = self.progress_bar(total=num_inference_steps)
|
||||
|
||||
while should_continue:
|
||||
block_tokens = block_x[:, -block_length:]
|
||||
masks_remaining = (block_tokens == mask_token_id).any()
|
||||
|
||||
if not masks_remaining:
|
||||
post_steps += 1
|
||||
|
||||
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
|
||||
block_logits = logits[:, -block_length:, :]
|
||||
|
||||
scheduler_output = self.scheduler.step(
|
||||
model_output=block_logits,
|
||||
timestep=step_idx,
|
||||
sample=block_tokens,
|
||||
mask_token_id=mask_token_id,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
sampling_method=sampling_method,
|
||||
threshold=threshold,
|
||||
editing_threshold=editing_threshold,
|
||||
minimal_topk=minimal_topk,
|
||||
prompt_mask=prompt_mask_in_block,
|
||||
generator=generator,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
transfer_index = scheduler_output.transfer_index
|
||||
editing_transfer_index = scheduler_output.editing_transfer_index
|
||||
final_transfer = transfer_index | editing_transfer_index
|
||||
|
||||
if final_transfer.any():
|
||||
block_x[:, -block_length:] = scheduler_output.prev_sample
|
||||
|
||||
if eos_early_stop and eos_token_id is not None:
|
||||
finished = self.scheduler.check_eos_finished(
|
||||
cur_x=block_x,
|
||||
sampled_tokens=scheduler_output.sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_token_id,
|
||||
mask_token_id=mask_token_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
|
||||
block_x = callback_outputs.pop("block_x", block_x)
|
||||
|
||||
global_step += 1
|
||||
if masks_remaining:
|
||||
step_idx += 1
|
||||
progress_bar.update(1)
|
||||
|
||||
should_continue = self.scheduler.check_block_should_continue(
|
||||
step_idx=step_idx,
|
||||
masks_remaining=masks_remaining,
|
||||
editing_enabled=editing_enabled,
|
||||
editing_transfer_index=editing_transfer_index,
|
||||
post_steps=post_steps,
|
||||
max_post_steps=max_post_steps,
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
progress_bar.close()
|
||||
x[:, :current_window_end] = block_x
|
||||
if eos_early_stop and finished.all():
|
||||
break
|
||||
|
||||
# 6. Post-process output
|
||||
generated = x[:, : prompt_length + gen_length]
|
||||
sequences = generated[:, prompt_length:]
|
||||
if eos_token_id is not None and batch_size == 1:
|
||||
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
||||
if len(eos_positions) > 0:
|
||||
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
|
||||
|
||||
texts = None
|
||||
if output_type == "text" and self.tokenizer is not None:
|
||||
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||
|
||||
if not return_dict:
|
||||
return sequences.to(device=device), texts
|
||||
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
|
||||
|
||||
|
||||
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
@@ -40,6 +40,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
||||
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
|
||||
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
||||
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
||||
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
||||
@@ -145,6 +146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
||||
from .scheduling_amused import AmusedScheduler
|
||||
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
|
||||
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
|
||||
460
src/diffusers/schedulers/scheduling_block_refinement.py
Normal file
460
src/diffusers/schedulers/scheduling_block_refinement.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockRefinementSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for block refinement scheduling.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Updated block tokens after the current refinement step.
|
||||
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were committed (mask-filling).
|
||||
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were edited (non-mask replacement).
|
||||
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Sampled token IDs from the model logits.
|
||||
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
|
||||
Probabilities of the sampled tokens.
|
||||
"""
|
||||
|
||||
prev_sample: torch.LongTensor
|
||||
transfer_index: torch.BoolTensor
|
||||
editing_transfer_index: torch.BoolTensor
|
||||
sampled_tokens: torch.LongTensor
|
||||
sampled_probs: torch.Tensor
|
||||
|
||||
|
||||
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Scheduler for block-wise iterative refinement (commit-by-confidence).
|
||||
|
||||
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
|
||||
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
|
||||
the number of refinement steps.
|
||||
|
||||
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
|
||||
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
threshold: float = 0.95,
|
||||
editing_threshold: float | None = None,
|
||||
minimal_topk: int = 1,
|
||||
):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
|
||||
self._transfer_schedule: torch.LongTensor | None = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
|
||||
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
|
||||
device=device if device is not None else "cpu"
|
||||
)
|
||||
|
||||
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
|
||||
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
|
||||
if num_inference_steps <= 0:
|
||||
return torch.zeros((0,), dtype=torch.long)
|
||||
base = block_length // num_inference_steps
|
||||
remainder = block_length % num_inference_steps
|
||||
out = torch.full((num_inference_steps,), base, dtype=torch.long)
|
||||
out[:remainder] += 1
|
||||
return out
|
||||
|
||||
# --- SAR sampling utilities ---
|
||||
|
||||
@staticmethod
|
||||
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
|
||||
"""Nucleus (top-p) logit filtering."""
|
||||
if top_p is None or top_p >= 1.0:
|
||||
return logits
|
||||
if not (0.0 < top_p <= 1.0):
|
||||
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
||||
cumulative_probs = sorted_probs.cumsum(dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > float(top_p)
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
|
||||
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
|
||||
"""Top-k logit filtering."""
|
||||
if top_k is None or top_k <= 0:
|
||||
return logits
|
||||
if top_k >= logits.shape[-1]:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k=top_k, dim=-1)
|
||||
min_keep = values[..., -1, None]
|
||||
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
|
||||
|
||||
@staticmethod
|
||||
def _sample_from_logits(
|
||||
logits: torch.Tensor,
|
||||
*,
|
||||
temperature: float,
|
||||
top_k: int | None,
|
||||
top_p: float | None,
|
||||
generator: torch.Generator | None,
|
||||
use_multinomial: bool,
|
||||
) -> tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
|
||||
if temperature < 0:
|
||||
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, vocab_size)
|
||||
|
||||
if temperature == 0.0 or not use_multinomial:
|
||||
probs = torch.softmax(flat_logits.float(), dim=-1)
|
||||
token = flat_logits.argmax(dim=-1, keepdim=True)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
scaled = flat_logits
|
||||
if temperature != 1.0:
|
||||
scaled = flat_logits / temperature
|
||||
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
|
||||
|
||||
probs = torch.softmax(filtered.float(), dim=-1)
|
||||
token = torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
    def step(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor,
        sample: torch.LongTensor,
        *,
        mask_token_id: int,
        temperature: float = 0.0,
        top_p: float | None = None,
        top_k: int | None = None,
        sampling_method: str = "auto",
        threshold: float | None = None,
        editing_threshold: float | None = None,
        minimal_topk: int | None = None,
        prompt_mask: torch.BoolTensor | None = None,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
    ) -> (
        BlockRefinementSchedulerOutput
        | tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
    ):
        """
        Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
        ones.

        Args:
            model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
                Raw logits from the model for the current block.
            timestep (`int` or `torch.Tensor`):
                Current step index within the block's refinement schedule.
            sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
                Current block token IDs (contains mask tokens for uncommitted positions).
            mask_token_id (`int`):
                Token ID used for masked positions.
            temperature (`float`):
                Sampling temperature.
            top_p (`float`, *optional*):
                Nucleus sampling cutoff.
            top_k (`int`, *optional*):
                Top-k sampling cutoff.
            sampling_method (`str`):
                Sampling method (`auto`, `greedy`, `multinomial`).
            threshold (`float`, *optional*):
                Confidence threshold for committing tokens. Defaults to config value.
            editing_threshold (`float`, *optional*):
                Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
                config value.
            minimal_topk (`int`, *optional*):
                Minimum tokens to commit per step. Defaults to config value.
            prompt_mask (`torch.BoolTensor`, *optional*):
                Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
            generator (`torch.Generator`, *optional*):
                RNG for sampling.
            return_dict (`bool`):
                Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
        """
        # Fall back to the scheduler config for any per-call override left unspecified.
        if threshold is None:
            threshold = float(self.config.threshold)
        if editing_threshold is None:
            editing_threshold = self.config.editing_threshold
        if minimal_topk is None:
            minimal_topk = self.config.minimal_topk
        # NOTE(review): `minimal_topk` is resolved here but never used below — confirm whether the
        # per-step commit count was meant to be lower-bounded by it.

        # Sample from logits
        use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
        sampled_tokens, sampled_probs = self._sample_from_logits(
            model_output,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            generator=generator,
            use_multinomial=use_multinomial,
        )

        batch_size, block_length = sample.shape
        # Positions still holding the mask token are the ones eligible for committing.
        active_block = sample == mask_token_id
        masks_remaining = active_block.any()

        if isinstance(timestep, torch.Tensor):
            step_index = int(timestep.item())
        else:
            step_index = int(timestep)

        # --- Mask-filling transfer ---
        transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
        if masks_remaining and self._transfer_schedule is not None:
            # Clamp so a step index past the schedule reuses the last per-step budget.
            clamped_step = min(step_index, len(self._transfer_schedule) - 1)
            num_to_transfer = int(self._transfer_schedule[clamped_step].item())

            # Non-mask positions get -inf confidence so they can never be selected here.
            confidence = torch.where(
                active_block,
                sampled_probs.to(dtype=torch.float32),
                torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
            )

            for b in range(batch_size):
                high_conf = confidence[b] > threshold
                if high_conf.sum().item() >= num_to_transfer:
                    # Enough confident candidates: commit all of them (may exceed the scheduled count).
                    transfer_index[b] = high_conf
                else:
                    # Otherwise top up to the scheduled count with the most confident remaining masks.
                    k = min(num_to_transfer, int(active_block[b].sum().item()))
                    if k > 0:
                        _, idx = torch.topk(confidence[b], k=k)
                        transfer_index[b, idx] = True

        # --- Editing transfer (non-mask, non-prompt positions) ---
        editing_enabled = editing_threshold is not None and editing_threshold > 0.0
        editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
        if editing_enabled:
            if prompt_mask is None:
                prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
            # Editable = already committed and not part of the prompt.
            editable = (~active_block) & (~prompt_mask.unsqueeze(0))
            editing_conf = torch.where(
                editable,
                sampled_probs.to(dtype=torch.float32),
                torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
            )
            high_conf_edit = editing_conf > float(editing_threshold)
            # Only edit when the model confidently prefers a *different* token than the current one.
            token_changed = sampled_tokens != sample
            editing_transfer_index = high_conf_edit & token_changed & editable

        # Apply transfers
        final_transfer = transfer_index | editing_transfer_index
        prev_sample = sample.clone()
        if final_transfer.any():
            prev_sample[final_transfer] = sampled_tokens[final_transfer]

        if not return_dict:
            return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
        return BlockRefinementSchedulerOutput(
            prev_sample=prev_sample,
            transfer_index=transfer_index,
            editing_transfer_index=editing_transfer_index,
            sampled_tokens=sampled_tokens,
            sampled_probs=sampled_probs,
        )
|
||||
|
||||
    @staticmethod
    def check_eos_finished(
        cur_x: torch.LongTensor,
        sampled_tokens: torch.LongTensor,
        final_transfer: torch.BoolTensor,
        finished: torch.BoolTensor,
        eos_token_id: int,
        mask_token_id: int,
        prompt_length: int,
    ) -> torch.BoolTensor:
        """
        Update per-batch finished flags when EOS tokens are committed.

        Args:
            cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
                Current full sequence including all blocks up to the current window.
            sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
                Tokens sampled by the scheduler in this step.
            final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
                Combined mask of committed and edited positions.
            finished (`torch.BoolTensor` of shape `(batch_size,)`):
                Current per-batch finished flags.
            eos_token_id (`int`):
                EOS token ID.
            mask_token_id (`int`):
                Mask token ID.
            prompt_length (`int`):
                Number of prompt tokens at the start of the sequence.

        Returns:
            `torch.BoolTensor`: Updated finished flags (mutated in place and returned).
        """
        batch_size = cur_x.shape[0]
        for b in range(batch_size):
            if finished[b]:
                # Finished flags are sticky; never re-evaluate a finished sequence.
                continue
            # Only re-check a sequence when an EOS token was actually committed this step.
            eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
            if not eos_in_commits:
                continue
            eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
            if len(eos_pos[0]) == 0:
                continue
            # First EOS occurrence anywhere in the sequence (variable reused: tuple -> int index).
            eos_pos = int(eos_pos[0][0].item())
            if prompt_length >= eos_pos:
                # The first EOS lies inside the prompt region; ignore it.
                continue
            # Finished only when every generated token before the EOS is resolved (no masks remain).
            if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
                finished[b] = True
        return finished
||||
|
||||
def check_block_should_continue(
|
||||
self,
|
||||
step_idx: int,
|
||||
masks_remaining: bool,
|
||||
editing_enabled: bool,
|
||||
editing_transfer_index: torch.BoolTensor,
|
||||
post_steps: int,
|
||||
max_post_steps: int,
|
||||
finished: torch.BoolTensor,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine whether the inner refinement loop should continue for the current block.
|
||||
|
||||
Args:
|
||||
step_idx (`int`):
|
||||
Current refinement step index within this block.
|
||||
masks_remaining (`bool`):
|
||||
Whether any mask tokens remain in the block.
|
||||
editing_enabled (`bool`):
|
||||
Whether editing mode is active.
|
||||
editing_transfer_index (`torch.BoolTensor`):
|
||||
Which tokens were edited in this step.
|
||||
post_steps (`int`):
|
||||
Number of post-mask editing steps taken so far.
|
||||
max_post_steps (`int`):
|
||||
Maximum allowed post-mask editing steps.
|
||||
finished (`torch.BoolTensor`):
|
||||
Per-batch finished flags (from EOS detection).
|
||||
|
||||
Returns:
|
||||
`bool`: `True` if refinement should continue, `False` to break.
|
||||
"""
|
||||
if finished.all():
|
||||
return False
|
||||
if not masks_remaining and not editing_enabled:
|
||||
return False
|
||||
if not masks_remaining and not editing_transfer_index.any():
|
||||
return False
|
||||
if masks_remaining and step_idx >= self.num_inference_steps:
|
||||
return False
|
||||
if not masks_remaining and post_steps > max_post_steps:
|
||||
return False
|
||||
return True
|
||||
|
||||
    def add_noise(
        self,
        original_samples: torch.LongTensor,
        attention_mask: torch.LongTensor,
        *,
        prompt_length: int,
        block_length: int,
        mask_token_id: int,
        generator: torch.Generator | None = None,
    ) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
        """
        Apply the forward (noising) process for semi-autoregressive block masking.

        For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
        `mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
        one are the unmasked positions in the other.

        Args:
            original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
                Clean token IDs.
            attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
                Padding mask (1 for valid, 0 for padding).
            prompt_length (`int`):
                Number of leading prompt tokens to keep unmasked.
            block_length (`int`):
                Block size for masking.
            mask_token_id (`int`):
                Token ID to use for masked positions.
            generator (`torch.Generator`, *optional*):
                RNG for reproducibility.

        Returns:
            `tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
                `(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
                corresponding boolean masks.
        """
        batch_size, seq_len = original_samples.shape
        device = original_samples.device

        noisy = original_samples.clone()
        noisy_rev = original_samples.clone()
        masked = torch.zeros_like(original_samples, dtype=torch.bool)
        masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)

        valid = attention_mask.to(dtype=torch.bool)
        # Walk the sequence block by block, starting after the prompt (prompt is never masked).
        for block_start in range(prompt_length, seq_len, block_length):
            block_end = min(seq_len, block_start + block_length)
            seg_len = block_end - block_start
            if seg_len <= 0:
                continue

            # Draw one mask ratio per batch element, then Bernoulli-mask the block at that ratio.
            # Keep the two `torch.rand` calls in this order: both consume from `generator`.
            p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
            seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
            # Never mask padding positions.
            seg = seg & valid[:, block_start:block_end]
            # The complementary view masks exactly the valid positions the first view left unmasked.
            seg_rev = (~seg) & valid[:, block_start:block_end]

            masked[:, block_start:block_end] = seg
            masked_rev[:, block_start:block_end] = seg_rev

        noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
        noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
        return noisy, noisy_rev, masked, masked_rev
||||
|
||||
|
||||
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
@@ -11,6 +11,7 @@ from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
@@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
return snr
|
||||
|
||||
|
||||
def compute_confidence_aware_loss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
*,
|
||||
lambda_conf: float = 0.0,
|
||||
temperature: float = 1.0,
|
||||
per_token_weights: torch.Tensor | None = None,
|
||||
ignore_index: int = -100,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes a confidence-aware training loss for token classification-style heads.
|
||||
|
||||
This loss combines:
|
||||
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
|
||||
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
|
||||
|
||||
Args:
|
||||
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
|
||||
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
|
||||
are excluded from both losses.
|
||||
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
|
||||
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
|
||||
sharpen the distribution and change the strength of the confidence gradients.
|
||||
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
|
||||
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
|
||||
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
|
||||
"""
|
||||
if logits.ndim < 2:
|
||||
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
|
||||
if labels.shape != logits.shape[:-1]:
|
||||
raise ValueError(
|
||||
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
|
||||
)
|
||||
if temperature <= 0:
|
||||
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
|
||||
|
||||
valid = labels.ne(ignore_index)
|
||||
if per_token_weights is None:
|
||||
weights = torch.ones_like(labels, dtype=logits.dtype)
|
||||
else:
|
||||
if per_token_weights.shape != labels.shape:
|
||||
raise ValueError(
|
||||
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
|
||||
)
|
||||
weights = per_token_weights.to(dtype=logits.dtype)
|
||||
|
||||
# Supervised CE (optionally weighted).
|
||||
vocab_size = logits.shape[-1]
|
||||
per_token_nll = F.cross_entropy(
|
||||
logits.reshape(-1, vocab_size),
|
||||
labels.reshape(-1),
|
||||
reduction="none",
|
||||
ignore_index=ignore_index,
|
||||
).reshape_as(labels)
|
||||
|
||||
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
|
||||
|
||||
# Confidence loss: penalize entropy only where prediction is already correct.
|
||||
if lambda_conf == 0.0:
|
||||
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
|
||||
return loss_sft, loss_sft, loss_conf
|
||||
|
||||
with torch.no_grad():
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = valid & pred.eq(labels)
|
||||
|
||||
scaled_logits = logits.float()
|
||||
if temperature != 1.0:
|
||||
scaled_logits = scaled_logits / float(temperature)
|
||||
|
||||
probs = torch.softmax(scaled_logits, dim=-1)
|
||||
eps = torch.finfo(probs.dtype).tiny
|
||||
log_probs = torch.log(probs.clamp_min(eps))
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
|
||||
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
|
||||
|
||||
loss = loss_sft + float(lambda_conf) * loss_conf
|
||||
return loss, loss_sft, loss_conf
|
||||
|
||||
|
||||
def resolve_interpolation_mode(interpolation_type: str):
|
||||
"""
|
||||
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
|
||||
|
||||
@@ -521,6 +521,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(metaclass=DummyObject):
    """Import-time placeholder for `AutoencoderKLKVAE` used when the `torch` backend is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Reports the missing `torch` dependency instead of constructing the model.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
||||
|
||||
|
||||
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
    """Import-time placeholder for `AutoencoderKLKVAEVideo` used when the `torch` backend is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Reports the missing `torch` dependency instead of constructing the model.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -2488,6 +2518,36 @@ class AmusedScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlockRefinementScheduler(metaclass=DummyObject):
    """Import-time placeholder for `BlockRefinementScheduler` used when the `torch` backend is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Reports the missing `torch` dependency instead of constructing the scheduler.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
||||
|
||||
|
||||
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
    """Import-time placeholder for `BlockRefinementSchedulerOutput` used when the `torch` backend is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Reports the missing `torch` dependency instead of constructing the output object.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
||||
|
||||
|
||||
class CMStochasticIterativeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2222,6 +2222,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LLaDA2Pipeline(metaclass=DummyObject):
    """Import-time placeholder for `LLaDA2Pipeline` used when `torch` or `transformers` is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Reports the missing dependencies instead of constructing the pipeline.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
||||
|
||||
|
||||
class LLaDA2PipelineOutput(metaclass=DummyObject):
    """Import-time placeholder for `LLaDA2PipelineOutput` used when `torch` or `transformers` is unavailable."""

    # Backends that must be installed before the real class can be used.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Reports the missing dependencies instead of constructing the output object.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
||||
|
||||
|
||||
class LongCatImageEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -67,9 +68,11 @@ else:
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
|
||||
deprecate(
|
||||
"diffusers.utils.testing_utils",
|
||||
"1.0.0",
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
|
||||
)
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
|
||||
@@ -19,11 +19,16 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
@@ -333,5 +338,23 @@ def disable_full_determinism():
|
||||
torch.use_deterministic_algorithms(False)
|
||||
|
||||
|
||||
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
    """Drop-in for `functools.lru_cache` that bypasses the cache while `torch.export` is tracing.

    On torch < 2.7.0 (where `torch.compiler.is_exporting` is not relied upon), this degrades to a
    plain `lru_cache`.
    """

    def decorator(fn: Callable[P, T]):
        memoized = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
        if is_torch_version("<", "2.7.0"):
            # Export detection unavailable: behave exactly like lru_cache.
            return memoized

        @functools.wraps(fn)
        def dispatch(*args: P.args, **kwargs: P.kwargs):
            # During export, serving cached results could bake stale values into the traced graph.
            if torch.compiler.is_exporting():
                return fn(*args, **kwargs)
            return memoized(*args, **kwargs)

        return dispatch

    return decorator
||||
|
||||
|
||||
if is_torch_available():
|
||||
torch_device = get_device()
|
||||
|
||||
@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
is_flaky,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"base_dim": 3,
|
||||
"z_dim": 4,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"latents_mean": torch.randn(4).numpy().tolist(),
|
||||
"latents_std": torch.randn(4).numpy().tolist(),
|
||||
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
|
||||
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    """Shared model-tester suite instantiated for the 2D `AutoencoderKLKVAE`."""

    model_class = AutoencoderKLKVAE
    main_input_name = "sample"
    # Looser tolerance used by the shared mixin comparisons.
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_config(self):
        # Deliberately tiny configuration so the shared tests run quickly.
        return {
            "in_channels": 3,
            "channels": 32,
            "num_enc_blocks": 1,
            "num_dec_blocks": 1,
            "z_channels": 4,
            "double_z": True,
            "ch_mult": (1, 2),
            "sample_size": 32,
        }

    @property
    def dummy_input(self):
        # Small random image batch matching `input_shape`.
        batch_size = 2
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        # Module classes that the mixin expects to receive gradient checkpointing.
        expected_set = {
            "KVAEEncoder2D",
            "KVAEDecoder2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
||||
@@ -0,0 +1,118 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    """Shared model-tester suite instantiated for the 3D (video) `AutoencoderKLKVAEVideo`."""

    model_class = AutoencoderKLKVAEVideo
    main_input_name = "sample"
    # Looser tolerance used by the shared mixin comparisons.
    base_precision = 1e-2

    def get_autoencoder_kl_kvae_video_config(self):
        # Deliberately tiny configuration so the shared tests run quickly.
        return {
            "ch": 32,
            "ch_mult": (1, 2),
            "num_res_blocks": 1,
            "in_channels": 3,
            "out_ch": 3,
            "z_channels": 4,
            "temporal_compress_times": 2,
        }

    @property
    def dummy_input(self):
        # Small random video batch matching `input_shape`.
        batch_size = 2
        num_frames = 3  # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
        num_channels = 3
        sizes = (16, 16)

        video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": video}

    @property
    def input_shape(self):
        return (3, 3, 16, 16)

    @property
    def output_shape(self):
        return (3, 3, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_kvae_video_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        # Module classes that the mixin expects to receive gradient checkpointing.
        expected_set = {
            "KVAECachedEncoder3D",
            "KVAECachedDecoder3D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Unsupported test.")
    def test_outputs_equivalence(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_model_parallelism(self):
        pass

    @unittest.skip(
        "Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
    )
    def test_sharded_checkpoints_device_map(self):
        pass

    def _run_nondeterministic(self, fn):
        """Run `fn` with deterministic algorithms disabled, then restore the previous setting."""
        # reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
        # temporarily relax the requirement for training tests that do backward passes.
        import torch

        # Fix: restore the *previous* determinism setting instead of unconditionally forcing `True`,
        # so this helper composes correctly with callers that run without full determinism enabled.
        previous = torch.are_deterministic_algorithms_enabled()
        torch.use_deterministic_algorithms(False)
        try:
            fn()
        finally:
            torch.use_deterministic_algorithms(previous)

    def test_training(self):
        self._run_nondeterministic(super().test_training)

    def test_ema_training(self):
        self._run_nondeterministic(super().test_ema_training)

    @unittest.skip(
        "Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
        "that is mutated during the first forward. On recomputation the cache is already populated, "
        "causing a different execution path and numerically different gradients. "
        "GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
    )
    def test_effective_gradient_checkpointing(self):
        pass

    def test_layerwise_casting_training(self):
        self._run_nondeterministic(super().test_layerwise_casting_training)
||||
@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_inference(self, cp_type):
|
||||
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
@@ -194,6 +200,10 @@ class ContextParallelTesterMixin:
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
@@ -209,6 +219,11 @@ class ContextParallelTesterMixin:
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
|
||||
@@ -41,7 +41,6 @@ from ..testing_utils import (
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
@@ -151,8 +150,7 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
@@ -219,6 +217,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
@@ -412,10 +414,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
|
||||
"""BitsAndBytes + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
|
||||
@@ -13,48 +13,94 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Flux2Transformer2DModel, attention_backend
|
||||
from diffusers import Flux2Transformer2DModel
|
||||
from diffusers.models.transformers.transformer_flux2 import (
|
||||
Flux2KVAttnProcessor,
|
||||
Flux2KVCache,
|
||||
Flux2KVLayerCache,
|
||||
Flux2KVParallelSelfAttnProcessor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
@@ -82,8 +128,286 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
|
||||
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux2 Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
"""TorchAO + compile tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
|
||||
num_ref_tokens = 4
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -91,72 +415,210 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"timestep_guidance_channels": 256,
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
num_ref_tokens = self.num_ref_tokens
|
||||
|
||||
# TODO (Daniel, Sayak): We can remove this test.
|
||||
def test_flux2_consistency(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ref_hidden_states = randn_tensor(
|
||||
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
img_hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
model = self.model_class(**init_dict)
|
||||
# state_dict = model.state_dict()
|
||||
# for key, param in state_dict.items():
|
||||
# print(f"{key} | {param.shape}")
|
||||
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
ref_t_coords = torch.arange(1)
|
||||
ref_h_coords = torch.arange(num_ref_tokens)
|
||||
ref_w_coords = torch.arange(1)
|
||||
ref_l_coords = torch.arange(1)
|
||||
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
|
||||
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
image_ids = torch.cat([ref_ids, image_ids], dim=1)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
|
||||
"""KV cache tests for Flux2 Transformer."""
|
||||
|
||||
def test_kv_layer_cache_store_and_get(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
k = torch.randn(1, 4, 2, 16)
|
||||
v = torch.randn(1, 4, 2, 16)
|
||||
cache.store(k, v)
|
||||
k_out, v_out = cache.get()
|
||||
assert torch.equal(k, k_out)
|
||||
assert torch.equal(v, v_out)
|
||||
|
||||
def test_kv_layer_cache_get_before_store_raises(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
try:
|
||||
cache.get()
|
||||
assert False, "Expected RuntimeError"
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def test_kv_layer_cache_clear(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.k_ref is None
|
||||
assert cache.v_ref is None
|
||||
|
||||
def test_kv_cache_structure(self):
|
||||
num_double = 3
|
||||
num_single = 2
|
||||
cache = Flux2KVCache(num_double, num_single)
|
||||
assert len(cache.double_block_caches) == num_double
|
||||
assert len(cache.single_block_caches) == num_single
|
||||
assert cache.num_ref_tokens == 0
|
||||
|
||||
for i in range(num_double):
|
||||
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
|
||||
for i in range(num_single):
|
||||
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
|
||||
|
||||
def test_kv_cache_clear(self):
|
||||
cache = Flux2KVCache(2, 1)
|
||||
cache.num_ref_tokens = 4
|
||||
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.num_ref_tokens == 0
|
||||
assert cache.get_double(0).k_ref is None
|
||||
|
||||
def _set_kv_attn_processors(self, model):
|
||||
for block in model.transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVAttnProcessor())
|
||||
for block in model.single_transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
||||
|
||||
@torch.no_grad()
|
||||
def test_extract_mode_returns_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
self._set_kv_attn_processors(model)
|
||||
|
||||
output = model(
|
||||
**self.get_dummy_inputs(),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
assert output.kv_cache is not None
|
||||
assert isinstance(output.kv_cache, Flux2KVCache)
|
||||
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
|
||||
|
||||
for layer_cache in output.kv_cache.double_block_caches:
|
||||
assert layer_cache.k_ref is not None
|
||||
assert layer_cache.v_ref is not None
|
||||
|
||||
for layer_cache in output.kv_cache.single_block_caches:
|
||||
assert layer_cache.k_ref is not None
|
||||
assert layer_cache.v_ref is not None
|
||||
|
||||
@torch.no_grad()
|
||||
def test_extract_mode_output_shape(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with attention_backend("native"):
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
height, width = 4, 4
|
||||
output = model(
|
||||
**self.get_dummy_inputs(height=height, width=width),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
assert output.sample.shape == (1, height * width, 4)
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
@torch.no_grad()
|
||||
def test_cached_mode_uses_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# input & output have to have the same shape
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
height, width = 4, 4
|
||||
extract_output = model(
|
||||
**self.get_dummy_inputs(height=height, width=width),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
# Check against expected slice
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
|
||||
# fmt: on
|
||||
base_config = Flux2TransformerTesterConfig()
|
||||
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
|
||||
cached_output = model(
|
||||
**cached_inputs,
|
||||
kv_cache=extract_output.kv_cache,
|
||||
kv_cache_mode="cached",
|
||||
)
|
||||
|
||||
flat_output = output.cpu().flatten()
|
||||
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
|
||||
assert cached_output.sample.shape == (1, height * width, 4)
|
||||
assert cached_output.kv_cache is None
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@torch.no_grad()
|
||||
def test_extract_return_dict_false(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(
|
||||
**self.get_dummy_inputs(),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
assert isinstance(output, tuple)
|
||||
assert len(output) == 2
|
||||
assert isinstance(output[1], Flux2KVCache)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
@torch.no_grad()
|
||||
def test_no_kv_cache_mode_returns_no_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
base_config = Flux2TransformerTesterConfig()
|
||||
output = model(**base_config.get_dummy_inputs())
|
||||
|
||||
|
||||
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
assert output.kv_cache is None
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
@@ -77,8 +78,7 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = embedding_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
@@ -106,9 +106,10 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_infers_text_seq_len_from_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -122,7 +123,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2
|
||||
assert normalized_mask.sum().item() == 2 * batch_size
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
@@ -139,7 +140,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5
|
||||
assert normalized_mask2.sum().item() == 5 * batch_size
|
||||
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
@@ -149,9 +150,10 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_non_contiguous_attention_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -284,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_linear(self):
|
||||
super().test_hotswapping_compiled_model_linear()
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self):
|
||||
super().test_hotswapping_compiled_model_both_linear_and_other()
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@@ -13,58 +13,63 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import SD3Transformer2DModel
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
# ======================== SD3 Transformer ========================
|
||||
|
||||
|
||||
class SD3TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return SD3Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = width = embedding_dim = 32
|
||||
pooled_embedding_dim = embedding_dim * 2
|
||||
sequence_length = 154
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-sd3-pipe"
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
@@ -79,67 +84,79 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"dual_attention_layers": (),
|
||||
"qk_norm": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
|
||||
"xformers is not enabled"
|
||||
)
|
||||
|
||||
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SD3Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = width = embedding_dim = 32
|
||||
pooled_embedding_dim = embedding_dim * 2
|
||||
sequence_length = 154
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestSD3Transformer(SD3TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestSD3TransformerTraining(SD3TransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SD3Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestSD3TransformerCompile(SD3TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
# ======================== SD3.5 Transformer ========================
|
||||
|
||||
|
||||
class SD35TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def input_shape(self):
|
||||
def model_class(self):
|
||||
return SD3Transformer2DModel
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-sd35-pipe"
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
@@ -154,47 +171,56 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"dual_attention_layers": (0,),
|
||||
"qk_norm": "rms_norm",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = width = embedding_dim = 32
|
||||
pooled_embedding_dim = embedding_dim * 2
|
||||
sequence_length = 154
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
|
||||
"xformers is not enabled"
|
||||
)
|
||||
|
||||
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SD3Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestSD35Transformer(SD35TransformerTesterConfig, ModelTesterMixin):
|
||||
def test_skip_layers(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Forward pass without skipping layers
|
||||
output_full = model(**inputs_dict).sample
|
||||
|
||||
# Forward pass with skipping layers 0 (since there's only one layer in this test setup)
|
||||
inputs_dict_with_skip = inputs_dict.copy()
|
||||
inputs_dict_with_skip["skip_layers"] = [0]
|
||||
output_skip = model(**inputs_dict_with_skip).sample
|
||||
|
||||
# Check that the outputs are different
|
||||
self.assertFalse(
|
||||
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
|
||||
)
|
||||
assert not torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
|
||||
assert output_full.shape == output_skip.shape, "Outputs should have the same shape"
|
||||
|
||||
# Check that the outputs have the same shape
|
||||
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
|
||||
|
||||
class TestSD35TransformerTraining(SD35TransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SD3Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestSD35TransformerCompile(SD35TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestSD35TransformerBitsAndBytes(SD35TransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for SD3.5 Transformer."""
|
||||
|
||||
|
||||
class TestSD35TransformerTorchAo(SD35TransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for SD3.5 Transformer."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -32,6 +33,33 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
|
||||
from diffusers.training_utils import set_seed
|
||||
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
|
||||
|
||||
from ..testing_utils import slow
|
||||
|
||||
@@ -85,3 +85,47 @@ class TrainingTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
|
||||
|
||||
def test_confidence_aware_loss(self):
|
||||
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
|
||||
labels = torch.tensor([[0, 0]])
|
||||
weights = torch.tensor([[1.0, 2.0]])
|
||||
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=0.0, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss, loss_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
|
||||
|
||||
lambda_conf = 0.25
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
|
||||
)
|
||||
|
||||
# Manual expected values for the small 2-class case.
|
||||
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
|
||||
1, 2
|
||||
)
|
||||
expected_sft = (per_token_nll * weights).sum() / weights.sum()
|
||||
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = pred.eq(labels)
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
probs = log_probs.exp()
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
|
||||
weights * correct.to(weights.dtype)
|
||||
).sum().clamp_min(1)
|
||||
|
||||
expected = expected_sft + lambda_conf * expected_conf
|
||||
self.assertTrue(torch.allclose(loss_sft, expected_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss, expected))
|
||||
|
||||
# Temperature affects only the confidence term.
|
||||
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
|
||||
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert str(warning.warning) == "This message is better!!!"
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
def test_deprecate_testing_utils_module(self):
|
||||
import diffusers.utils.testing_utils
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
importlib.reload(diffusers.utils.testing_utils)
|
||||
|
||||
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
|
||||
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
|
||||
|
||||
messages = [str(w.message) for w in deprecation_warnings]
|
||||
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
|
||||
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
|
||||
)
|
||||
assert any(
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
|
||||
for msg in messages
|
||||
), f"Expected deprecation message substring not found, got: {messages}"
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
|
||||
0
tests/pipelines/llada2/__init__.py
Normal file
0
tests/pipelines/llada2/__init__.py
Normal file
245
tests/pipelines/llada2/test_llada2.py
Normal file
245
tests/pipelines/llada2/test_llada2.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
|
||||
class _DummyModelOutput:
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
|
||||
|
||||
class _DummyCausalLM(torch.nn.Module):
|
||||
def __init__(self, vocab_size: int):
|
||||
super().__init__()
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.register_buffer("_device_anchor", torch.empty(0))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device_anchor.device
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
|
||||
|
||||
# Make confidence vary with token position so top-k commits are deterministic.
|
||||
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
|
||||
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
|
||||
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
|
||||
return _DummyModelOutput(logits=logits)
|
||||
|
||||
|
||||
def _make_pipeline(tokenizer=None):
|
||||
model = _DummyCausalLM(vocab_size=32)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class LLaDA2PipelineTest(unittest.TestCase):
|
||||
def test_pipeline_runs(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
|
||||
out = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=24,
|
||||
block_length=8,
|
||||
num_inference_steps=8,
|
||||
temperature=0.0,
|
||||
threshold=2.0, # force top-k commits
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
eos_token_id=None,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertEqual(out.sequences.shape, (2, 24))
|
||||
self.assertFalse((out.sequences == 31).any().item())
|
||||
|
||||
def test_pipeline_return_tuple(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
|
||||
sequences, texts = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertEqual(sequences.shape, (1, 16))
|
||||
self.assertIsNone(texts)
|
||||
|
||||
def test_output_type_seq(self):
|
||||
"""output_type='seq' should return sequences but no texts."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertEqual(out.sequences.shape, (1, 16))
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_without_tokenizer(self):
|
||||
"""output_type='text' without a tokenizer should return texts=None."""
|
||||
pipe = _make_pipeline(tokenizer=None).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_with_tokenizer(self):
|
||||
"""output_type='text' with a tokenizer should return decoded texts."""
|
||||
tok = type(
|
||||
"Tok",
|
||||
(),
|
||||
{
|
||||
"eos_token_id": None,
|
||||
"mask_token_id": 31,
|
||||
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
|
||||
},
|
||||
)()
|
||||
pipe = _make_pipeline(tokenizer=tok).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNotNone(out.texts)
|
||||
self.assertEqual(len(out.texts), 1)
|
||||
self.assertTrue(out.texts[0].startswith("decoded_"))
|
||||
|
||||
def test_output_type_invalid_raises(self):
|
||||
"""Invalid output_type should raise ValueError."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
mask_token_id=31,
|
||||
output_type="invalid",
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_from_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertTrue(torch.equal(result, ids))
|
||||
|
||||
def test_prepare_input_ids_from_1d_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([1, 2, 3], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertEqual(result.shape, (1, 3))
|
||||
|
||||
def test_prepare_input_ids_no_tokenizer_raises(self):
|
||||
pipe = _make_pipeline(tokenizer=None)
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_neither_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||
|
||||
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
self.assertTrue(all(device == torch_device for device in model_devices))
|
||||
|
||||
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
|
||||
470
tests/schedulers/test_scheduler_block_refinement.py
Normal file
470
tests/schedulers/test_scheduler_block_refinement.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
|
||||
|
||||
class BlockRefinementSchedulerTest(unittest.TestCase):
|
||||
def get_scheduler(self, **kwargs):
|
||||
config = {
|
||||
"block_length": 32,
|
||||
"num_inference_steps": 8,
|
||||
"threshold": 0.95,
|
||||
"editing_threshold": None,
|
||||
"minimal_topk": 1,
|
||||
}
|
||||
config.update(kwargs)
|
||||
return BlockRefinementScheduler(**config)
|
||||
|
||||
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Create logits where softmax of the target token has approximately the given probability."""
|
||||
batch_size, block_length = target_probs.shape
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
# Set token 0 as the "predicted" token with a logit proportional to desired probability
|
||||
for b in range(batch_size):
|
||||
for t in range(block_length):
|
||||
p = target_probs[b, t].item()
|
||||
if p > 0:
|
||||
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
|
||||
return logits
|
||||
|
||||
def test_set_timesteps(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
self.assertEqual(scheduler.num_inference_steps, 8)
|
||||
self.assertEqual(len(scheduler.timesteps), 8)
|
||||
self.assertEqual(scheduler.timesteps[0].item(), 7)
|
||||
self.assertEqual(scheduler.timesteps[-1].item(), 0)
|
||||
|
||||
def test_set_timesteps_invalid(self):
|
||||
scheduler = self.get_scheduler()
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler.set_timesteps(0)
|
||||
|
||||
def test_get_num_transfer_tokens_even(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
|
||||
self.assertEqual(schedule.sum().item(), 32)
|
||||
self.assertEqual(len(schedule), 8)
|
||||
self.assertTrue((schedule == 4).all().item())
|
||||
|
||||
def test_get_num_transfer_tokens_remainder(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
|
||||
self.assertEqual(schedule.sum().item(), 10)
|
||||
self.assertEqual(len(schedule), 3)
|
||||
self.assertEqual(schedule[0].item(), 4)
|
||||
self.assertEqual(schedule[1].item(), 3)
|
||||
self.assertEqual(schedule[2].item(), 3)
|
||||
|
||||
def test_transfer_schedule_created_on_set_timesteps(self):
|
||||
scheduler = self.get_scheduler(block_length=16)
|
||||
scheduler.set_timesteps(4)
|
||||
self.assertIsNotNone(scheduler._transfer_schedule)
|
||||
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
|
||||
|
||||
def test_save_load_config_round_trip(self):
|
||||
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
scheduler.save_config(tmpdir)
|
||||
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
|
||||
|
||||
self.assertEqual(loaded.config.block_length, 64)
|
||||
self.assertEqual(loaded.config.threshold, 0.8)
|
||||
self.assertEqual(loaded.config.editing_threshold, 0.5)
|
||||
self.assertEqual(loaded.config.minimal_topk, 2)
|
||||
|
||||
def test_from_config(self):
|
||||
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
|
||||
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
|
||||
self.assertEqual(new_scheduler.config.block_length, 16)
|
||||
self.assertEqual(new_scheduler.config.threshold, 0.7)
|
||||
|
||||
def test_step_commits_tokens(self):
|
||||
"""Verify that step() commits mask tokens based on confidence."""
|
||||
scheduler = self.get_scheduler(block_length=8)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, block_length, vocab_size = 1, 8, 32
|
||||
mask_id = 31
|
||||
|
||||
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
|
||||
# Create logits where confidence decreases with position
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
for i in range(block_length):
|
||||
logits[0, i, i] = 10.0 - i # decreasing confidence
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
threshold=0.95,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# With 8 tokens and 2 steps, first step should commit 4 tokens
|
||||
committed = out.transfer_index[0].sum().item()
|
||||
self.assertEqual(committed, 4)
|
||||
|
||||
def test_step_no_editing_by_default(self):
|
||||
"""Without editing_threshold, no non-mask tokens should be changed."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 10.0 # predict token 15 for all positions
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=None,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index.any().item())
|
||||
self.assertFalse(out.transfer_index[0, 0].item())
|
||||
self.assertFalse(out.transfer_index[0, 1].item())
|
||||
|
||||
def test_step_editing_replaces_tokens(self):
|
||||
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
# Token 0: predict 50 (different from 10) with very high logit
|
||||
logits[0, 0, 15] = 20.0
|
||||
# Token 1: predict 20 (same as current)
|
||||
logits[0, 1, 20] = 20.0
|
||||
# Mask tokens
|
||||
logits[0, 2, 5] = 5.0
|
||||
logits[0, 3, 6] = 5.0
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# Token 0 should be edited (different prediction, high confidence)
|
||||
self.assertTrue(out.editing_transfer_index[0, 0].item())
|
||||
# Token 1 should NOT be edited (same prediction)
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_prompt_mask_prevents_editing(self):
|
||||
"""Prompt positions should never be edited even with editing enabled."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 20.0
|
||||
prompt_mask = torch.tensor([True, True, False, False])
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
prompt_mask=prompt_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index[0, 0].item())
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_return_tuple(self):
|
||||
"""Verify tuple output when return_dict=False."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.full((1, 4), 31, dtype=torch.long)
|
||||
logits = torch.randn(1, 4, vocab_size)
|
||||
|
||||
result = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.assertEqual(len(result), 5)
|
||||
|
||||
def test_step_batched(self):
|
||||
"""Verify step works with batch_size > 1."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, vocab_size = 3, 32
|
||||
mask_id = 31
|
||||
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
|
||||
logits = torch.randn(batch_size, 4, vocab_size)
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
|
||||
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
|
||||
|
||||
def test_check_block_should_continue_finished(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([True, True])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=0,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_no_masks_no_edits(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=5,
|
||||
masks_remaining=False,
|
||||
editing_enabled=True,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=1,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_steps_exhausted(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=8,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_eos_finished_marks_batch(self):
|
||||
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, token, eos, mask, mask]
|
||||
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False, False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_ignores_when_masks_before_eos(self):
|
||||
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
|
||||
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertFalse(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_already_finished(self):
|
||||
"""Already-finished batches should stay finished."""
|
||||
mask_id, eos_id = 99, 2
|
||||
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False]])
|
||||
finished = torch.tensor([True])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=2,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_add_noise(self):
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
mask_token_id = 99
|
||||
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
prompt_length=2,
|
||||
block_length=4,
|
||||
mask_token_id=mask_token_id,
|
||||
generator=gen,
|
||||
)
|
||||
|
||||
# Prompt positions should never be masked
|
||||
self.assertFalse(masked[0, 0].item())
|
||||
self.assertFalse(masked[0, 1].item())
|
||||
self.assertFalse(masked_rev[0, 0].item())
|
||||
self.assertFalse(masked_rev[0, 1].item())
|
||||
|
||||
# Noisy should have mask_token_id where masked is True
|
||||
self.assertTrue((noisy[masked] == mask_token_id).all().item())
|
||||
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
|
||||
|
||||
# masked and masked_rev should be complementary within valid non-prompt positions
|
||||
non_prompt = torch.zeros_like(masked)
|
||||
non_prompt[0, 2:] = True
|
||||
combined = masked | masked_rev
|
||||
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
|
||||
|
||||
|
||||
class TestTopPFiltering(unittest.TestCase):
|
||||
def test_top_p_filtering(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
|
||||
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
|
||||
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
|
||||
|
||||
def test_top_p_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_p_filtering_one(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestTopKFiltering(unittest.TestCase):
|
||||
def test_top_k_filtering(self):
|
||||
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
|
||||
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
|
||||
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
|
||||
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
|
||||
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
|
||||
|
||||
def test_top_k_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_zero(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_large_k(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestSampleFromLogits(unittest.TestCase):
|
||||
def test_greedy_sampling(self):
|
||||
logits = torch.tensor([[1.0, 5.0, 2.0]])
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
self.assertEqual(tokens.shape, (1,))
|
||||
self.assertEqual(probs.shape, (1,))
|
||||
|
||||
def test_multinomial_sampling(self):
|
||||
logits = torch.tensor([[0.0, 100.0, -100.0]])
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=gen,
|
||||
use_multinomial=True,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
|
||||
def test_temperature_scaling(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
tokens, _ = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.01,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 2)
|
||||
|
||||
def test_negative_temperature_raises(self):
|
||||
logits = torch.tensor([[1.0, 2.0]])
|
||||
with self.assertRaises(ValueError):
|
||||
BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=-1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
|
||||
|
||||
|
||||
def fetch_pipeline_objects():
|
||||
models = api.list_models(library="diffusers")
|
||||
models = api.list_models(filter="diffusers")
|
||||
downloads = defaultdict(int)
|
||||
|
||||
for model in models:
|
||||
|
||||
Reference in New Issue
Block a user