Compare commits

...

33 Commits

Author SHA1 Message Date
Sayak Paul
58f7ede8f2 Merge branch 'main' into model-test-refactor 2026-01-16 14:10:07 +05:30
dg845
8af8e86bc7 LTX 2 Single File Support (#12983)
* LTX 2 transformer single file support

* LTX 2 video VAE single file support

* LTX 2 audio VAE single file support

* Make it easier to distinguish LTX 1 and 2 models
2026-01-15 22:46:42 -08:00
Sayak Paul
74654df203 add klein docs. (#12984) 2026-01-16 10:12:42 +05:30
YiYi Xu
f112eab97e [modular] fix a bug in mellon param & improve docstrings (#12980)
* update mellonparams docstring to incude the acutal param definition render in mellon

* style

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
2026-01-15 10:42:42 -10:00
YiYi Xu
61f175660a Flux2 klein (#12982)
* flux2-klein

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Klein tests (#2)

* tests

* up

* tests

* up

* support step-distilled

* Apply suggestions from code review

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* doc string etc

* style

* more

* copies

* klein lora training scripts (#3)

* initial commit

* initial commit

* remove remote text encoder

* initial commit

* initial commit

* initial commit

* revert

* img2img fix

* text encoder + tokenizer

* text encoder + tokenizer

* update readme

* guidance

* guidance

* guidance

* test

* test

* revert changes not needed for the non klein model

* Update examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix guidance

* fix validation

* fix validation

* fix validation

* fix path

* space

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* style

* Update src/diffusers/pipelines/flux2/pipeline_flux2_klein.py

* Apply style fixes

* auto pipeline

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-01-15 09:10:54 -10:00
DN6
69f68c1a65 update 2026-01-13 15:49:21 +05:30
DN6
da801e97ba update 2026-01-13 15:15:13 +05:30
DN6
ce3097c65b update 2026-01-13 15:11:18 +05:30
DN6
6dfba74c21 update 2026-01-13 14:21:16 +05:30
DN6
3620e4ffee update 2026-01-13 13:33:40 +05:30
DN6
7334262fd9 update 2026-01-13 12:25:08 +05:30
DN6
0a639d1843 update 2026-01-13 11:18:46 +05:30
DN6
9362584c66 update 2026-01-13 10:52:50 +05:30
DN6
5c2d30623e update 2026-01-13 10:38:16 +05:30
DN6
6caa0a9bf4 update 2026-01-08 12:21:13 +05:30
DN6
ba475eee8d update 2026-01-08 12:21:13 +05:30
Sayak Paul
e0ab03d79b Merge branch 'main' into model-test-refactor 2025-12-31 21:03:32 +05:30
DN6
7b3ef42a01 update 2025-12-26 12:45:30 +05:30
DN6
c70de2bc37 update 2025-12-18 13:18:54 +05:30
DN6
e82001e40d update 2025-12-18 13:16:50 +05:30
DN6
d9b73ffd51 update 2025-12-15 16:12:50 +05:30
DN6
dcd6026d17 update 2025-12-15 16:12:15 +05:30
DN6
eae7543712 update 2025-12-15 16:02:38 +05:30
DN6
d08e0bb545 update 2025-12-15 14:19:27 +05:30
DN6
c366b5a817 update 2025-12-11 13:37:06 +05:30
DN6
0fdd9d3a60 update 2025-12-11 11:41:17 +05:30
DN6
489480b02a update 2025-12-11 11:27:59 +05:30
DN6
fe451c367b update 2025-12-11 11:04:47 +05:30
DN6
0f1a4e0c14 update 2025-11-19 21:59:20 +05:30
DN6
aa29af8f0e update 2025-11-19 08:51:38 +05:30
DN6
bffa3a9754 update 2025-11-14 15:48:19 +05:30
DN6
1c558712e8 Merge branch 'main' into model-test-refactor 2025-11-12 10:18:07 +05:30
DN6
1f026ad14e update 2025-11-12 10:17:54 +05:30
35 changed files with 11768 additions and 181 deletions

View File

@@ -35,5 +35,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
- all
- __call__
## Flux2KleinPipeline
[[autodoc]] Flux2KleinPipeline
- all
- __call__

View File

@@ -1,14 +1,22 @@
# DreamBooth training example for FLUX.2 [dev]
# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scripts shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).
> [!NOTE]
> **Model Variants**
>
> We support two FLUX model families:
> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.
> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
@@ -17,7 +25,7 @@ The `train_dreambooth_lora_flux2.py` script shows how to implement the training
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:
```bash
hf auth login
@@ -88,23 +96,32 @@ snapshot_download(
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
### Remote Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
> [!IMPORTANT]
> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
This way, it distributes the memory cost across multiple nodes.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
@@ -114,22 +131,29 @@ enable FP8 training by passing `--do_fp8_training`.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### FLUX.2 [dev] Training
To perform DreamBooth with LoRA on FLUX.2 [dev], run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -161,13 +185,84 @@ accelerate launch train_dreambooth_lora_flux2.py \
--push_to_hub
```
### FLUX 2 [klein] Training
FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1.
> [!NOTE]
> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported.
**FLUX 2 [klein] 4B:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-klein-4b"
accelerate launch train_dreambooth_lora_flux2_klein.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
**FLUX 2 [klein] 9B:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-klein-9b"
accelerate launch train_dreambooth_lora_flux2_klein.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases.
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
@@ -189,12 +284,6 @@ fsdp_config:
fsdp_cpu_ram_efficient_loading: false
```
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
@@ -206,8 +295,6 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -271,13 +358,10 @@ the exact modules for LoRA training. Here are some examples of target modules yo
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
**important**
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
@@ -334,5 +418,6 @@ we've added aspect ratio bucketing support which allows training on images with
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
`
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗

View File

@@ -0,0 +1,262 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "dog"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein"
script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
def test_dreambooth_lora_flux2(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--max_sequence_length 8
--checkpointing_steps=2
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)

View File

@@ -127,7 +127,7 @@ def save_model_card(
)
model_description = f"""
# Flux DreamBooth LoRA - {repo_id}
# Flux.2 DreamBooth LoRA - {repo_id}
<Gallery />

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
@@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) ->
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "test" or model_type == "dummy-flux2":
if model_type == "flux2-dev":
config = {
"model_id": "diffusers-internal-dev/dummy-flux2",
"model_id": "black-forest-labs/FLUX.2-dev",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
@@ -405,6 +405,53 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "klein-4b":
config = {
"model_id": "diffusers-internal-dev/dummy0115",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 5,
"num_single_layers": 20,
"attention_head_dim": 128,
"num_attention_heads": 24,
"joint_attention_dim": 7680,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
"guidance_embeds": False,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "klein-9b":
config = {
"model_id": "diffusers-internal-dev/dummy0115",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 8,
"num_single_layers": 24,
"attention_head_dim": 128,
"num_attention_heads": 32,
"joint_attention_dim": 12288,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
"guidance_embeds": False,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
else:
raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b")
return config, rename_dict, special_keys_remap
@@ -447,7 +494,14 @@ def main(args):
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
if "klein-4b" in args.dit_filename:
model_type = "klein-4b"
elif "klein-9b" in args.dit_filename:
model_type = "klein-9b"
else:
model_type = "flux2-dev"
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type)
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
@@ -465,8 +519,15 @@ def main(args):
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
if_distilled = "base" not in args.dit_filename
pipe = Flux2Pipeline(
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
if_distilled=if_distilled,
)
pipe.save_pretrained(args.output_path)

View File

@@ -481,6 +481,7 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"Flux2KleinPipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
@@ -1208,6 +1209,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
Flux2KleinPipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,

View File

@@ -40,6 +40,9 @@ from .single_file_utils import (
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx2_audio_vae_to_diffusers,
convert_ltx2_transformer_to_diffusers,
convert_ltx2_vae_to_diffusers,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
@@ -176,6 +179,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
"LTX2VideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTX2Video": {
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderKLLTX2Audio": {
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
}

View File

@@ -112,7 +112,8 @@ CHECKPOINT_KEY_NAMES = {
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -147,6 +148,11 @@ CHECKPOINT_KEY_NAMES = {
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
"ltx2": [
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
"vae.per_channel_statistics.mean-of-means",
"audio_vae.per_channel_statistics.mean-of-means",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -228,6 +234,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}
# Use to configure model sample size when original config is provided
@@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
model_type = "ltx2-dev"
else:
model_type = "v1"
@@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Transformer prefix
"model.diffusion_model.": "",
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Video VAE prefix
"vae.": "",
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict) -> None:
state_dict.pop(key)
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
# Audio VAE prefix
"audio_vae.": "",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
return converted_state_dict

View File

@@ -585,7 +585,13 @@ class Flux2PosEmbed(nn.Module):
class Flux2TimestepGuidanceEmbeddings(nn.Module):
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
def __init__(
self,
in_channels: int = 256,
embedding_dim: int = 6144,
bias: bool = False,
guidance_embeds: bool = True,
):
super().__init__()
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -593,20 +599,24 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
)
else:
self.guidance_embedder = None
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
if guidance is not None and self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
return time_guidance_emb
else:
return timesteps_emb
class Flux2Modulation(nn.Module):
@@ -698,6 +708,7 @@ class Flux2Transformer2DModel(
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -708,7 +719,10 @@ class Flux2Transformer2DModel(
# 2. Combined timestep + guidance embedding
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
in_channels=timestep_guidance_channels,
embedding_dim=self.inner_dim,
bias=False,
guidance_embeds=guidance_embeds,
)
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
@@ -815,7 +829,9 @@ class Flux2Transformer2DModel(
# 1. Calculate timestep embedding and modulation parameters
timestep = timestep.to(hidden_states.dtype) * 1000
guidance = guidance.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_guidance_embed(timestep, guidance)

View File

@@ -23,10 +23,18 @@ logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MellonParam:
"""
Parameter definition for Mellon nodes.
Parameter definition for Mellon nodes.
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...",
label="...", type="...").
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with
MellonParam(name="...", label="...", type="...").
Example:
```python
# Custom param
MellonParam(name="my_param", label="My Param", type="float", default=0.5)
# Output in Mellon node definition:
# "my_param": {"label": "My Param", "type": "float", "default": 0.5}
```
"""
name: str
@@ -51,14 +59,32 @@ class MellonParam:
@classmethod
def image(cls) -> "MellonParam":
"""
Image input parameter.
Mellon node definition:
"image": {"label": "Image", "type": "image", "display": "input"}
"""
return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
@classmethod
def images(cls) -> "MellonParam":
"""
Images output parameter.
Mellon node definition:
"images": {"label": "Images", "type": "image", "display": "output"}
"""
return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
@classmethod
def control_image(cls, display: str = "input") -> "MellonParam":
"""
Control image parameter for ControlNet.
Mellon node definition (display="input"):
"control_image": {"label": "Control Image", "type": "image", "display": "input"}
"""
return cls(
name="control_image",
label="Control Image",
@@ -69,10 +95,25 @@ class MellonParam:
@classmethod
def latents(cls, display: str = "input") -> "MellonParam":
"""
Latents parameter.
Mellon node definition (display="input"):
"latents": {"label": "Latents", "type": "latents", "display": "input"}
Mellon node definition (display="output"):
"latents": {"label": "Latents", "type": "latents", "display": "output"}
"""
return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
@classmethod
def image_latents(cls, display: str = "input") -> "MellonParam":
"""
Image latents parameter for img2img workflows.
Mellon node definition (display="input"):
"image_latents": {"label": "Image Latents", "type": "latents", "display": "input"}
"""
return cls(
name="image_latents",
label="Image Latents",
@@ -83,6 +124,12 @@ class MellonParam:
@classmethod
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
"""
First frame latents for video generation.
Mellon node definition (display="input"):
"first_frame_latents": {"label": "First Frame Latents", "type": "latents", "display": "input"}
"""
return cls(
name="first_frame_latents",
label="First Frame Latents",
@@ -93,6 +140,16 @@ class MellonParam:
@classmethod
def image_latents_with_strength(cls) -> "MellonParam":
"""
Image latents with strength-based onChange behavior. When connected, shows strength slider; when disconnected,
shows height/width.
Mellon node definition:
"image_latents": {
"label": "Image Latents", "type": "latents", "display": "input", "onChange": {"false": ["height",
"width"], "true": ["strength"]}
}
"""
return cls(
name="image_latents",
label="Image Latents",
@@ -105,16 +162,34 @@ class MellonParam:
@classmethod
def latents_preview(cls) -> "MellonParam":
"""
`Latents Preview` is a special output parameter that is used to preview the latents in the UI.
Latents preview output for visualizing latents in the UI.
Mellon node definition:
"latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}
"""
return cls(name="latents_preview", label="Latents Preview", type="latent", display="output")
@classmethod
def embeddings(cls, display: str = "output") -> "MellonParam":
"""
Text embeddings parameter.
Mellon node definition (display="output"):
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}
Mellon node definition (display="input"):
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "input"}
"""
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
@classmethod
def image_embeds(cls, display: str = "output") -> "MellonParam":
"""
Image embeddings parameter for IP-Adapter workflows.
Mellon node definition (display="output"):
"image_embeds": {"label": "Image Embeddings", "type": "image_embeds", "display": "output"}
"""
return cls(
name="image_embeds",
label="Image Embeddings",
@@ -125,6 +200,15 @@ class MellonParam:
@classmethod
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
"""
ControlNet conditioning scale slider.
Mellon node definition (default=0.5):
"controlnet_conditioning_scale": {
"label": "Controlnet Conditioning Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0,
"step": 0.01
}
"""
return cls(
name="controlnet_conditioning_scale",
label="Controlnet Conditioning Scale",
@@ -138,6 +222,15 @@ class MellonParam:
@classmethod
def control_guidance_start(cls, default: float = 0.0) -> "MellonParam":
"""
Control guidance start timestep.
Mellon node definition (default=0.0):
"control_guidance_start": {
"label": "Control Guidance Start", "type": "float", "default": 0.0, "min": 0.0, "max": 1.0, "step":
0.01
}
"""
return cls(
name="control_guidance_start",
label="Control Guidance Start",
@@ -151,6 +244,14 @@ class MellonParam:
@classmethod
def control_guidance_end(cls, default: float = 1.0) -> "MellonParam":
"""
Control guidance end timestep.
Mellon node definition (default=1.0):
"control_guidance_end": {
"label": "Control Guidance End", "type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01
}
"""
return cls(
name="control_guidance_end",
label="Control Guidance End",
@@ -164,6 +265,12 @@ class MellonParam:
@classmethod
def prompt(cls, default: str = "") -> "MellonParam":
"""
Text prompt input as textarea.
Mellon node definition (default=""):
"prompt": {"label": "Prompt", "type": "string", "default": "", "display": "textarea"}
"""
return cls(
name="prompt",
label="Prompt",
@@ -175,6 +282,12 @@ class MellonParam:
@classmethod
def negative_prompt(cls, default: str = "") -> "MellonParam":
"""
Negative prompt input as textarea.
Mellon node definition (default=""):
"negative_prompt": {"label": "Negative Prompt", "type": "string", "default": "", "display": "textarea"}
"""
return cls(
name="negative_prompt",
label="Negative Prompt",
@@ -186,6 +299,12 @@ class MellonParam:
@classmethod
def strength(cls, default: float = 0.5) -> "MellonParam":
"""
Denoising strength for img2img.
Mellon node definition (default=0.5):
"strength": {"label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}
"""
return cls(
name="strength",
label="Strength",
@@ -199,6 +318,15 @@ class MellonParam:
@classmethod
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
"""
CFG guidance scale slider.
Mellon node definition (default=5.0):
"guidance_scale": {
"label": "Guidance Scale", "type": "float", "display": "slider", "default": 5.0, "min": 1.0, "max":
30.0, "step": 0.1
}
"""
return cls(
name="guidance_scale",
label="Guidance Scale",
@@ -212,6 +340,12 @@ class MellonParam:
@classmethod
def height(cls, default: int = 1024) -> "MellonParam":
"""
Image height in pixels.
Mellon node definition (default=1024):
"height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8}
"""
return cls(
name="height",
label="Height",
@@ -224,12 +358,26 @@ class MellonParam:
@classmethod
def width(cls, default: int = 1024) -> "MellonParam":
"""
Image width in pixels.
Mellon node definition (default=1024):
"width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8}
"""
return cls(
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
)
@classmethod
def seed(cls, default: int = 0) -> "MellonParam":
"""
Random seed with randomize button.
Mellon node definition (default=0):
"seed": {
"label": "Seed", "type": "int", "default": 0, "min": 0, "max": 4294967295, "display": "random"
}
"""
return cls(
name="seed",
label="Seed",
@@ -243,6 +391,14 @@ class MellonParam:
@classmethod
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
"""
Number of denoising steps slider.
Mellon node definition (default=25):
"num_inference_steps": {
"label": "Steps", "type": "int", "default": 25, "min": 1, "max": 100, "display": "slider"
}
"""
return cls(
name="num_inference_steps",
label="Steps",
@@ -256,6 +412,12 @@ class MellonParam:
@classmethod
def num_frames(cls, default: int = 81) -> "MellonParam":
"""
Number of video frames slider.
Mellon node definition (default=81):
"num_frames": {"label": "Frames", "type": "int", "default": 81, "min": 1, "max": 480, "display": "slider"}
"""
return cls(
name="num_frames",
label="Frames",
@@ -269,6 +431,12 @@ class MellonParam:
@classmethod
def layers(cls, default: int = 4) -> "MellonParam":
"""
Number of layers slider (for layered diffusion).
Mellon node definition (default=4):
"layers": {"label": "Layers", "type": "int", "default": 4, "min": 1, "max": 10, "display": "slider"}
"""
return cls(
name="layers",
label="Layers",
@@ -282,15 +450,24 @@ class MellonParam:
@classmethod
def videos(cls) -> "MellonParam":
"""
Video output parameter.
Mellon node definition:
"videos": {"label": "Videos", "type": "video", "display": "output"}
"""
return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
@classmethod
def vae(cls) -> "MellonParam":
"""
VAE model info dict.
VAE model input.
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
Mellon node definition:
"vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input"}
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
components.get_one(model_id) to retrieve the actual model.
"""
return cls(
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
@@ -299,10 +476,13 @@ class MellonParam:
@classmethod
def image_encoder(cls) -> "MellonParam":
"""
Image Encoder model info dict.
Image encoder model input.
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
Mellon node definition:
"image_encoder": {"label": "Image Encoder", "type": "diffusers_auto_model", "display": "input"}
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
components.get_one(model_id) to retrieve the actual model.
"""
return cls(
name="image_encoder",
@@ -315,30 +495,39 @@ class MellonParam:
@classmethod
def unet(cls) -> "MellonParam":
"""
Denoising model (UNet/Transformer) info dict.
Denoising model (UNet/Transformer) input.
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
Mellon node definition:
"unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
components.get_one(model_id) to retrieve the actual model.
"""
return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input")
@classmethod
def scheduler(cls) -> "MellonParam":
"""
Scheduler model info dict.
Scheduler model input.
Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual
scheduler.
Mellon node definition:
"scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}
Note: The value received is a model info dict with keys like 'model_id', 'repo_id'. Use
components.get_one(model_id) to retrieve the actual scheduler.
"""
return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input")
@classmethod
def controlnet(cls) -> "MellonParam":
"""
ControlNet model info dict.
ControlNet model input.
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
Mellon node definition:
"controlnet": {"label": "ControlNet Model", "type": "diffusers_auto_model", "display": "input"}
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
components.get_one(model_id) to retrieve the actual model.
"""
return cls(
name="controlnet",
@@ -351,12 +540,17 @@ class MellonParam:
@classmethod
def text_encoders(cls) -> "MellonParam":
"""
Dict of text encoder model info dicts.
Text encoders dict input (multiple encoders).
Structure: {
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
'repo_id': '...'
} Use components.get_one(model_id) to retrieve each model.
Mellon node definition:
"text_encoders": {"label": "Text Encoders", "type": "diffusers_auto_models", "display": "input"}
Note: The value received is a dict of model info dicts:
{
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
'repo_id': '...'
}
Use components.get_one(model_id) to retrieve each model.
"""
return cls(
name="text_encoders",
@@ -369,15 +563,20 @@ class MellonParam:
@classmethod
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
"""
ControlNet bundle containing model info and processed control inputs.
ControlNet bundle containing model and processed control inputs. Output from ControlNet node, input to Denoise
node.
Structure: {
'controlnet': {'model_id': ..., ...}, # controlnet model info dict 'control_image': ..., # processed
control image/embeddings 'controlnet_conditioning_scale': ..., ... # other inputs expected by denoise
blocks
}
Mellon node definition (display="input"):
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "input"}
Output from Controlnet node, input to Denoise node.
Mellon node definition (display="output"):
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "output"}
Note: The value is a dict containing:
{
'controlnet': {'model_id': ..., ...}, # controlnet model info 'control_image': ..., # processed control
image/embeddings 'controlnet_conditioning_scale': ..., # and other denoise block inputs
}
"""
return cls(
name="controlnet_bundle",
@@ -389,10 +588,25 @@ class MellonParam:
@classmethod
def ip_adapter(cls) -> "MellonParam":
"""
IP-Adapter input.
Mellon node definition:
"ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}
"""
return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input")
@classmethod
def guider(cls) -> "MellonParam":
"""
Custom guider input. When connected, hides the guidance_scale slider.
Mellon node definition:
"guider": {
"label": "Guider", "type": "custom_guider", "display": "input", "onChange": {false: ["guidance_scale"],
true: []}
}
"""
return cls(
name="guider",
label="Guider",
@@ -403,6 +617,12 @@ class MellonParam:
@classmethod
def doc(cls) -> "MellonParam":
"""
Documentation output for inspecting the underlying modular pipeline.
Mellon node definition:
"doc": {"label": "Doc", "type": "string", "display": "output"}
"""
return cls(name="doc", label="Doc", type="string", display="output")
@@ -415,6 +635,7 @@ DEFAULT_NODE_SPECS = {
MellonParam.height(),
MellonParam.seed(),
MellonParam.num_inference_steps(),
MellonParam.num_frames(),
MellonParam.guidance_scale(),
MellonParam.strength(),
MellonParam.image_latents_with_strength(),
@@ -669,6 +890,9 @@ class MellonPipelineConfig:
@property
def node_params(self) -> Dict[str, Any]:
"""Lazily compute node_params from node_specs."""
if self.node_specs is None:
return self._node_params
params = {}
for node_type, spec in self.node_specs.items():
if spec is None:
@@ -711,7 +935,8 @@ class MellonPipelineConfig:
Note: The mellon_params are already in Mellon format when loading from JSON.
"""
instance = cls.__new__(cls)
instance.node_params = data.get("node_params", {})
instance.node_specs = None
instance._node_params = data.get("node_params", {})
instance.label = data.get("label", "")
instance.default_repo = data.get("default_repo", "")
instance.default_dtype = data.get("default_dtype", "")

View File

@@ -130,7 +130,7 @@ else:
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -678,7 +678,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .flux2 import Flux2Pipeline
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline

View File

@@ -52,6 +52,7 @@ from .flux import (
FluxKontextPipeline,
FluxPipeline,
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
@@ -164,6 +165,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("flux-kontext", FluxKontextPipeline),
("flux2-klein", Flux2KleinPipeline),
("flux2", Flux2Pipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
("chroma", ChromaPipeline),
@@ -202,6 +205,8 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
("flux-kontext", FluxKontextPipeline),
("flux2-klein", Flux2KleinPipeline),
("flux2", Flux2Pipeline),
("qwenimage", QwenImageImg2ImgPipeline),
("qwenimage-edit", QwenImageEditPipeline),
("qwenimage-edit-plus", QwenImageEditPlusPipeline),

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_flux2 import Flux2Pipeline
from .pipeline_flux2_klein import Flux2KleinPipeline
else:
import sys

View File

@@ -725,8 +725,8 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
@@ -975,7 +975,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self._attention_kwargs,
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -0,0 +1,918 @@
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
from ...loaders import Flux2LoraLoaderMixin
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import Flux2KleinPipeline
>>> pipe = Flux2KleinPipeline.from_pretrained(
... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0]
>>> image.save("flux2_output.png")
```
"""
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
r"""
The Flux2 Klein pipeline for text-to-image generation.
Reference:
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
Args:
transformer ([`Flux2Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLFlux2`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen3ForCausalLM`]):
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
tokenizer (`Qwen2TokenizerFast`):
Tokenizer of class
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLFlux2,
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
transformer: Flux2Transformer2DModel,
is_distilled: bool = False,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.register_to_config(is_distilled=is_distilled)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 512
self.default_sample_size = 128
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
def _prepare_text_ids(
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: Optional[torch.Tensor] = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
def _prepare_latent_ids(
latents: torch.Tensor, # (B, C, H, W)
):
r"""
Generates 4D position coordinates (T, H, W, L) for latent tensors.
Args:
latents (torch.Tensor):
Latent tensor of shape (B, C, H, W)
Returns:
torch.Tensor:
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
H=[0..H-1], W=[0..W-1], L=0
"""
batch_size, _, height, width = latents.shape
t = torch.arange(1) # [0] - time dimension
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1) # [0] - layer dimension
# Create position IDs: (H*W, 4)
latent_ids = torch.cartesian_prod(t, h, w, l)
# Expand to batch: (B, H*W, 4)
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
return latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
def _prepare_image_ids(
image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
scale: int = 10,
):
r"""
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
This function creates a unique coordinate for every pixel/patch across all input latent with different
dimensions.
Args:
image_latents (List[torch.Tensor]):
A list of image latent feature tensors, typically of shape (C, H, W).
scale (int, optional):
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
latent is: 'scale + scale * i'. Defaults to 10.
Returns:
torch.Tensor:
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
input latents.
Coordinate Components (Dimension 4):
- T (Time): The unique index indicating which latent image the coordinate belongs to.
- H (Height): The row index within that latent image.
- W (Width): The column index within that latent image.
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
"""
if not isinstance(image_latents, list):
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
# create time offset for each reference image
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
def _patchify_latents(latents):
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
def _unpatchify_latents(latents):
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
latents = latents.permute(0, 1, 4, 2, 5, 3)
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
def _pack_latents(latents):
"""
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
"""
using position ids to scatter tokens into place
"""
x_list = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
out = out.view(h, w, ch).permute(2, 0, 1)
x_list.append(out)
return torch.stack(x_list, dim=0)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (9, 18, 27),
):
device = device or self._execution_device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=max_sequence_length,
hidden_states_layers=text_encoder_out_layers,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self._prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if image.ndim != 4:
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_latents_channels,
height,
width,
dtype,
device,
generator: torch.Generator,
latents: Optional[torch.Tensor] = None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
latent_ids = self._prepare_latent_ids(latents)
latent_ids = latent_ids.to(device)
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
return latents, latent_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
def prepare_image_latents(
self,
images: List[torch.Tensor],
batch_size,
generator: torch.Generator,
device,
dtype,
):
image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
imagge_latent = self._encode_vae_image(image=image, generator=generator)
image_latents.append(imagge_latent) # (1, 128, 32, 32)
image_latent_ids = self._prepare_image_ids(image_latents)
# Pack each latent and concatenate
packed_latents = []
for latent in image_latents:
# latent: (1, 128, 32, 32)
packed = self._pack_latents(latent) # (1, 1024, 128)
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
packed_latents.append(packed)
# Concatenate all reference tokens along sequence dimension
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
image_latents = image_latents.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.to(device)
return image_latents, image_latent_ids
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale=None,
):
if (
height is not None
and height % (self.vae_scale_factor * 2) != 0
or width is not None
and width % (self.vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if guidance_scale > 1.0 and self.config.is_distilled:
logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and not self.config.is_distilled
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: Optional[float] = 4.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[Union[str, List[str]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (9, 18, 27),
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models,
`guidance_scale` is ignored.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline.
If not provided, will be generated from "".
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
text_encoder_out_layers (`Tuple[int]`):
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
Examples:
Returns:
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
guidance_scale=guidance_scale,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. prepare text embeddings
prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
text_encoder_out_layers=text_encoder_out_layers,
)
if self.do_classifier_free_guidance:
negative_prompt = ""
if prompt is not None and isinstance(prompt, list):
negative_prompt = [negative_prompt] * len(prompt)
negative_prompt_embeds, negative_text_ids = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
text_encoder_out_layers=text_encoder_out_layers,
)
# 4. process images
if image is not None and not isinstance(image, list):
image = [image]
condition_images = None
if image is not None:
for img in image:
self.image_processor.check_image_input(img)
condition_images = []
for img in image:
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
multiple_of = self.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
condition_images.append(img)
height = height or image_height
width = width or image_width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 5. prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_ids = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_latents_channels=num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
image_latents = None
image_latent_ids = None
if condition_images is not None:
image_latents, image_latent_ids = self.prepare_image_latents(
images=condition_images,
batch_size=batch_size * num_images_per_prompt,
generator=generator,
device=device,
dtype=self.vae.dtype,
)
# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
sigmas = None
image_seq_len = latents.shape[1]
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
latent_model_input = latents.to(self.transformer.dtype)
latent_image_ids = latent_ids
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input, # (B, image_seq_len, C)
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids, # B, text_seq_len, 4
img_ids=latent_image_ids, # B, image_seq_len, 4
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1) :]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self._attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1) :]
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
if output_type == "latent":
image = latents
else:
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return Flux2PipelineOutput(images=image)

View File

@@ -947,6 +947,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -32,6 +32,22 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
config.addinivalue_line("markers", "slow: mark test as slow")
config.addinivalue_line("markers", "nightly: mark test as nightly")

View File

@@ -0,0 +1,79 @@
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFConfigMixin,
GGUFTesterMixin,
ModelOptCompileTesterMixin,
ModelOptConfigMixin,
ModelOptTesterMixin,
QuantizationCompileTesterMixin,
QuantizationTesterMixin,
QuantoCompileTesterMixin,
QuantoConfigMixin,
QuantoTesterMixin,
TorchAoCompileTesterMixin,
TorchAoConfigMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]

View File

@@ -0,0 +1,181 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
)
@is_attention
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
model.fuse_qkv_projections()
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=atol,
rtol=rtol,
msg="Output should not change after fusing projections",
)
model.unfuse_qkv_projections()
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=atol,
rtol=rtol,
msg="Output should match original after unfusing projections",
)
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

View File

@@ -0,0 +1,556 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
"""Decorator to skip tests if model doesn't use CacheMixin."""
def wrapper(self, *args, **kwargs):
if not issubclass(self.model_class, CacheMixin):
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
return func(self, *args, **kwargs)
return wrapper
class CacheTesterMixin:
"""
Base mixin class providing common test implementations for cache testing.
Cache-specific mixins should:
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
2. Inherit from this mixin
3. Define the cache config to use for tests
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods in test classes:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional overrides:
- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
"""
@property
def cache_input_key(self):
return "hidden_states"
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _get_cache_config(self):
"""
Get the cache config for testing.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclass must implement _get_cache_config")
def _get_hook_names(self):
"""
Get the hook names to check for this cache type.
Should be implemented by subclasses.
Returns a list of hook name strings.
"""
raise NotImplementedError("Subclass must implement _get_hook_names")
def _test_cache_enable_disable_state(self):
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Initially cache should not be enabled
assert not model.is_cache_enabled, "Cache should not be enabled initially."
config = self._get_cache_config()
# Enable cache
model.enable_cache(config)
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
# Disable cache
model.disable_cache()
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
def _test_cache_double_enable_raises_error(self):
"""Test that enabling cache twice raises an error."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
model.enable_cache(config)
# Trying to enable again should raise ValueError
with pytest.raises(ValueError, match="Caching has already been enabled"):
model.enable_cache(config)
# Cleanup
model.disable_cache()
def _test_cache_hooks_registered(self):
"""Test that cache hooks are properly registered and removed."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
hook_names = self._get_hook_names()
model.enable_cache(config)
# Check that at least one hook was registered
hook_count = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count += 1
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
# Disable and verify hooks are removed
model.disable_cache()
hook_count_after = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count_after += 1
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with cache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass (vary input tensor to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs (produces approximated output)
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_cache_context_manager(self, atol=1e-5, rtol=0):
"""Test the cache_context context manager properly isolates cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# Run inference in first context
with model.cache_context("context_1"):
output_ctx1 = model(**inputs_dict, return_dict=False)[0]
# Run same inference in second context (cache should be reset)
with model.cache_context("context_2"):
output_ctx2 = model(**inputs_dict, return_dict=False)[0]
# Both contexts should produce the same output (first pass in each)
assert_tensors_close(
output_ctx1,
output_ctx2,
atol=atol,
rtol=rtol,
msg="First pass in different cache contexts should produce the same output.",
)
model.disable_cache()
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@is_cache
class PyramidAttentionBroadcastConfigMixin:
"""
Base mixin providing PyramidAttentionBroadcast cache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default PAB config - can be overridden by subclasses
PAB_CONFIG = {
"spatial_attention_block_skip_range": 2,
}
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
_current_timestep = 500
def _get_cache_config(self):
config_kwargs = self.PAB_CONFIG.copy()
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
return PyramidAttentionBroadcastConfig(**config_kwargs)
def _get_hook_names(self):
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
"""
Mixin class for testing PyramidAttentionBroadcast caching on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@require_cache_mixin
def test_pab_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_pab_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_pab_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_pab_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_pab_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_pab_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FirstBlockCacheConfigMixin:
"""
Base mixin providing FirstBlockCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FBC config - can be overridden by subclasses
# Higher threshold makes FBC more aggressive about caching (skips more often)
FBC_CONFIG = {
"threshold": 1.0,
}
def _get_cache_config(self):
return FirstBlockCacheConfig(**self.FBC_CONFIG)
def _get_hook_names(self):
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FirstBlockCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# FBC requires cache_context to be set for inference
with model.cache_context("fbc_test"):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass - FBC should skip remaining blocks and use cached residuals
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
with model.cache_context("fbc_test"):
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_fbc_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_fbc_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_fbc_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_fbc_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_fbc_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_fbc_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FasterCacheConfigMixin:
"""
Base mixin providing FasterCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FasterCache config - can be overridden by subclasses
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
}
def _get_cache_config(self, current_timestep_callback=None):
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
if current_timestep_callback is None:
current_timestep_callback = lambda: 1000 # noqa: E731
config_kwargs["current_timestep_callback"] = current_timestep_callback
return FasterCacheConfig(**config_kwargs)
def _get_hook_names(self):
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FasterCache on models.
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
and timestep management. Inference tests are skipped at model level - FasterCache should
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FasterCache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
current_timestep = [1000]
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
model.enable_cache(config)
# First pass with timestep outside skip range - computes and populates cache
current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
# Move timestep inside skip range so subsequent passes use cache
current_timestep[0] = 500
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FasterCache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_faster_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_faster_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_faster_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_faster_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_faster_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_faster_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()

View File

@@ -0,0 +1,666 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Type
import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import assert_tensors_close, torch_device
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: str | torch.device | None = None,
special_dtypes: dict[str, str | torch.device] | None = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(param_device), (
f"Expected device {param_device} for {param_name}, got {param.device}"
)
def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
if torch.is_tensor(inputs):
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
if isinstance(inputs, dict):
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
if isinstance(inputs, list):
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
return inputs
class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.
This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.
Required properties (must be implemented by subclasses):
- model_class: The model class to test
Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel
@property
def pretrained_model_name_or_path(self):
return "org/my-model"
@property
def output_shape(self):
return (1, 3, 32, 32)
def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}
def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""
# ==================== Required Properties ====================
@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")
# ==================== Optional Properties ====================
@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None
@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None
@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.5, 0.7]
# ==================== Required Methods ====================
def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.
Returns:
Dict[str, Any]: Initialization arguments for the model constructor.
Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.
Returns:
Dict[str, Any]: Input tensors/values for model.forward().
Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""
@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert param_1.shape == param_2.shape, (
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
@torch.no_grad()
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
@torch.no_grad()
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
def set_nan_tensor_to_zero(t):
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (list, tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=atol,
rtol=rtol,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
# save some things to test
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert caplog.text == ""
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert caplog.text == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with pytest.raises(AttributeError) as error:
model.does_not_exist
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if fp32_modules and any(
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype
inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
)
@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(tmp_path)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
)

View File

@@ -0,0 +1,166 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic shape testing."""
return None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True):
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
_ = model(**inputs_dict)
_ = model(**inputs_dict)

View File

@@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, processor_cls):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected from config mixin:
- model_class: The model class to test
Required properties (must be implemented by subclasses):
- ip_adapter_processor_cls: The IP Adapter processor class to use
Required methods (must be implemented by subclasses):
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@property
def ip_adapter_processor_cls(self):
"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
@torch.no_grad()
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter processors not set correctly"
)
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
"Output should differ with IP Adapter enabled"
)
@pytest.mark.skip(
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
)
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
"Output should differ with different IP Adapter scales"
)
@pytest.mark.skip(
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
)
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter should be unloaded"
)
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"

View File

@@ -0,0 +1,555 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
import pytest
import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)
if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
@torch.no_grad()
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=atol,
rtol=rtol,
msg="Outputs should match before and after save/load",
)
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.
Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic compilation tests."""
return None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config
def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(
self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
):
"""
Check that hotswapping works on a model.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=atol,
rtol=rtol,
msg="Output mismatch after hotswapping to adapter1",
)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)

View File

@@ -0,0 +1,498 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import cast_inputs_to_dtype, check_device_map_is_respected
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
@property
def model_split_percents(self) -> list[float]:
"""List of percentages for splitting model across devices during offloading tests."""
return [0.5, 0.7]
@require_offload_support
@torch.no_grad()
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(str(tmp_path))
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model.cpu().save_pretrained(str(tmp_path))
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0],
new_output[0],
atol=atol,
rtol=rtol,
msg="Output should match with disk offloading (safetensors)",
)
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading1,
atol=atol,
rtol=rtol,
msg="Output should match with block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading2,
atol=atol,
rtol=rtol,
msg="Output should match with non-blocking block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading3,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading4,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading with stream",
)
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
tmpdir = str(tmp_path)
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=rtol,
msg="Output should match with disk-based group offloading",
)
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
"Memory footprint should decrease with lower precision storage"
)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
"Peak memory should be lower with bf16 compute on newer GPUs"
)
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
model = self.model_class(**self.get_init_dict())
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict, return_dict=False)[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
test_fn(torch.float16, torch.float32)
test_fn(torch.float8_e4m3fn, torch.float32)
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""

View File

@@ -0,0 +1,99 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
@torch.no_grad()
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
result_queue.put(("success", output.shape))
except Exception as e:
if rank == 0:
result_queue.put(("error", str(e)))
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
cp_dict = {cp_type: world_size}
ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
mp.spawn(
_context_parallel_worker,
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
nprocs=world_size,
join=True,
)
status, result = result_queue.get(timeout=60)
assert status == "success", f"Context parallel inference failed: {result}"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,265 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Required properties (must be implemented by subclasses):
- ckpt_path: Path or Hub path to the single file checkpoint
Optional properties:
- torch_dtype: torch dtype to use for testing (default: None)
- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
Expected from config mixin:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
# ==================== Required Properties ====================
@property
def ckpt_path(self) -> str:
"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
# ==================== Optional Properties ====================
@property
def torch_dtype(self) -> torch.dtype | None:
"""torch dtype to use for single file testing."""
return None
@property
def alternate_ckpt_paths(self) -> list[str] | None:
"""List of alternate checkpoint paths for variant testing."""
return None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self, atol=1e-5, rtol=1e-5):
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict = model.state_dict()
state_dict_single_file = model_single_file.state_dict()
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert_tensors_close(
param, param_single_file, atol=atol, rtol=rtol, msg=f"Parameter values differ for {key}"
)
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {**self.pretrained_model_kwargs}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_loading_with_device_map(self):
single_file_kwargs = {"device_map": torch_device}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
assert model is not None, "Failed to load model with device_map"
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map"
assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map"
check_device_map_is_respected(model, model.hf_device_map)

View File

@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import (
backend_empty_cache,
is_training,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected from config mixin:
- model_class: The model class to test
- output_shape: Tuple defining the expected output shape
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict, return_dict=False)[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (loss - loss_2).abs() < loss_tolerance, (
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
f"Gradient mismatch for {name}"
)
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

View File

@@ -13,23 +13,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Any
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
# TODO: This standalone function maintains backward compatibility with pipeline tests
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
ip_cross_attn_state_dict = {}
key_id = 0
@@ -39,7 +67,7 @@ def create_flux_ip_adapter_state_dict(model):
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
@@ -50,11 +78,8 @@ def create_flux_ip_adapter_state_dict(model):
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
@@ -75,57 +100,37 @@ def create_flux_ip_adapter_state_dict(model):
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class FluxTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return FluxTransformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-pipe"
@property
def input_shape(self):
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 4)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -137,11 +142,40 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
"axes_dims_rope": [4, 4, 8],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, embedding_dim), generator=self.generator, device=torch_device
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -162,63 +196,269 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
@property
def ip_adapter_processor_cls(self):
return FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_flux_ip_adapter_state_dict(model)
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
@property
def ckpt_path(self):
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
@property
def alternate_ckpt_paths(self):
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
@property
def pretrained_model_name_or_path(self):
return "black-forest-labs/FLUX.1-dev"
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""
# Flux is guidance distilled, so we can test at model level without CFG batch handling
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}

View File

@@ -0,0 +1,183 @@
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
from diffusers import (
AutoencoderKLFlux2,
FlowMatchEulerDiscreteScheduler,
Flux2KleinPipeline,
Flux2Transformer2DModel,
)
from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Flux2KleinPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
supports_dduf = False
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = Flux2Transformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=16,
timestep_guidance_channels=256,
axes_dims_rope=[4, 4, 4, 4],
guidance_embeds=False,
)
# Create minimal Qwen3 config
config = Qwen3Config(
intermediate_size=16,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
vocab_size=151936,
max_position_embeddings=512,
)
torch.manual_seed(0)
text_encoder = Qwen3ForCausalLM(config)
# Use a simple tokenizer for testing
tokenizer = Qwen2TokenizerFast.from_pretrained(
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
)
torch.manual_seed(0)
vae = AutoencoderKLFlux2(
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "a dog is dancing",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 8,
"width": 8,
"max_sequence_length": 64,
"output_type": "np",
"text_encoder_out_layers": (1,),
}
return inputs
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
self.assertTrue(
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
("Fusion of QKV projections shouldn't affect the outputs."),
)
self.assertTrue(
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
)
self.assertTrue(
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
("Original outputs should match when fused QKV projections are disabled."),
)
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
self.assertEqual(
(output_height, output_width),
(expected_height, expected_width),
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
)
def test_image_input(self):
device = "cpu"
pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
inputs = self.get_dummy_inputs(device)
inputs["image"] = Image.new("RGB", (64, 64))
image = pipe(**inputs).images.flatten()
generated_slice = np.concatenate([image[:8], image[-8:]])
# fmt: off
expected_slice = np.array(
[
0.8255048 , 0.66054785, 0.6643694 , 0.67462724, 0.5494932 , 0.3480271 , 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916
]
)
# fmt: on
assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)
@unittest.skip("Needs to be revisited")
def test_encode_prompt_works_in_isolation(self):
pass

View File

@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -130,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
return True
def assert_tensors_close(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
) -> None:
"""
Assert that two tensors are close within tolerance.
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
Provides concise, actionable error messages without dumping full tensors.
Args:
actual: The actual tensor from the computation.
expected: The expected tensor to compare against.
atol: Absolute tolerance.
rtol: Relative tolerance.
msg: Optional message prefix for the assertion error.
Raises:
AssertionError: If tensors have different shapes or values exceed tolerance.
Example:
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
"""
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
abs_diff = (actual - expected).abs()
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
total = actual.numel()
raise AssertionError(
f"{msg}\n"
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
)
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()
@@ -241,7 +295,6 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -282,12 +335,155 @@ def nightly(test_case):
def is_torch_compile(test_case):
"""
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
return pytest.mark.compile(test_case)
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_quantization(test_case):
"""
Decorator marking a test as a quantization test. These tests can be filtered using:
pytest -m "not quantization" to skip
pytest -m quantization to run only these tests
"""
return pytest.mark.quantization(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def is_context_parallel(test_case):
"""
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
pytest -m "not context_parallel" to skip
pytest -m context_parallel to run only these tests
"""
return pytest.mark.context_parallel(test_case)
def is_cache(test_case):
"""
Decorator marking a test as a cache test. These tests can be filtered using:
pytest -m "not cache" to skip
pytest -m cache to run only these tests
"""
return pytest.mark.cache(test_case)
def require_torch(test_case):
@@ -650,6 +846,19 @@ def require_kernels_version_greater_or_equal(kernels_version):
return decorator
def require_modelopt_version_greater_or_equal(modelopt_version):
def decorator(test_case):
correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
version.parse(importlib.metadata.version("modelopt")).base_version
) >= version.parse(modelopt_version)
return pytest.mark.skipif(
not correct_nvidia_modelopt_version,
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -0,0 +1,592 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to generate test suites for diffusers model classes.
Usage:
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
This will analyze the model file and generate a test file with appropriate
test classes based on the model's mixins and attributes.
"""
import argparse
import ast
import sys
from pathlib import Path
MIXIN_TO_TESTER = {
"ModelMixin": "ModelTesterMixin",
"PeftAdapterMixin": "LoraTesterMixin",
}
ATTRIBUTE_TO_TESTER = {
"_cp_plan": "ContextParallelTesterMixin",
"_supports_gradient_checkpointing": "TrainingTesterMixin",
}
ALWAYS_INCLUDE_TESTERS = [
"ModelTesterMixin",
"MemoryTesterMixin",
"TorchCompileTesterMixin",
]
# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
"AttentionMixin",
"AttentionModuleMixin",
}
OPTIONAL_TESTERS = [
# Quantization testers
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
("TorchAoTesterMixin", "torchao"),
("GGUFTesterMixin", "gguf"),
("ModelOptTesterMixin", "modelopt"),
# Quantization compile testers
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
("QuantoCompileTesterMixin", "quanto_compile"),
("TorchAoCompileTesterMixin", "torchao_compile"),
("GGUFCompileTesterMixin", "gguf_compile"),
("ModelOptCompileTesterMixin", "modelopt_compile"),
# Cache testers
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
("FirstBlockCacheTesterMixin", "fbc_cache"),
("FasterCacheTesterMixin", "faster_cache"),
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
]
class ModelAnalyzer(ast.NodeVisitor):
def __init__(self):
self.model_classes = []
self.current_class = None
self.imports = set()
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.imports.add(alias.name.split(".")[-1])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
self.imports.add(alias.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef):
base_names = []
for base in node.bases:
if isinstance(base, ast.Name):
base_names.append(base.id)
elif isinstance(base, ast.Attribute):
base_names.append(base.attr)
if "ModelMixin" in base_names:
class_info = {
"name": node.name,
"bases": base_names,
"attributes": {},
"has_forward": False,
"init_params": [],
}
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
attr_name = target.id
if attr_name.startswith("_"):
class_info["attributes"][attr_name] = self._get_value(item.value)
elif isinstance(item, ast.FunctionDef):
if item.name == "forward":
class_info["has_forward"] = True
class_info["forward_params"] = self._extract_func_params(item)
elif item.name == "__init__":
class_info["init_params"] = self._extract_func_params(item)
self.model_classes.append(class_info)
self.generic_visit(node)
def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
params = []
args = func_node.args
num_defaults = len(args.defaults)
num_args = len(args.args)
first_default_idx = num_args - num_defaults
for i, arg in enumerate(args.args):
if arg.arg == "self":
continue
param_info = {"name": arg.arg, "type": None, "default": None}
if arg.annotation:
param_info["type"] = self._get_annotation_str(arg.annotation)
default_idx = i - first_default_idx
if default_idx >= 0 and default_idx < len(args.defaults):
param_info["default"] = self._get_value(args.defaults[default_idx])
params.append(param_info)
return params
def _get_annotation_str(self, node) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Constant):
return repr(node.value)
elif isinstance(node, ast.Subscript):
base = self._get_annotation_str(node.value)
if isinstance(node.slice, ast.Tuple):
args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
else:
args = self._get_annotation_str(node.slice)
return f"{base}[{args}]"
elif isinstance(node, ast.Attribute):
return f"{self._get_annotation_str(node.value)}.{node.attr}"
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = self._get_annotation_str(node.left)
right = self._get_annotation_str(node.right)
return f"{left} | {right}"
elif isinstance(node, ast.Tuple):
return ", ".join(self._get_annotation_str(el) for el in node.elts)
return "Any"
def _get_value(self, node):
if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.Name):
if node.id == "None":
return None
elif node.id == "True":
return True
elif node.id == "False":
return False
return node.id
elif isinstance(node, ast.List):
return [self._get_value(el) for el in node.elts]
elif isinstance(node, ast.Dict):
return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
return "<complex>"
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
with open(filepath) as f:
source = f.read()
tree = ast.parse(source)
analyzer = ModelAnalyzer()
analyzer.visit(tree)
return analyzer.model_classes, analyzer.imports
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
testers = list(ALWAYS_INCLUDE_TESTERS)
for base in model_info["bases"]:
if base in MIXIN_TO_TESTER:
tester = MIXIN_TO_TESTER[base]
if tester not in testers:
testers.append(tester)
for attr, tester in ATTRIBUTE_TO_TESTER.items():
if attr in model_info["attributes"]:
value = model_info["attributes"][attr]
if value is not None and value is not False:
if tester not in testers:
testers.append(tester)
if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
if "ContextParallelTesterMixin" not in testers:
testers.append("ContextParallelTesterMixin")
# Include AttentionTesterMixin if the model imports attention-related classes
if imports & ATTENTION_INDICATORS:
testers.append("AttentionTesterMixin")
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
testers.append(tester)
return testers
def generate_config_class(model_info: dict, model_name: str) -> str:
class_name = f"{model_name}TesterConfig"
model_class = model_info["name"]
forward_params = model_info.get("forward_params", [])
init_params = model_info.get("init_params", [])
lines = [
f"class {class_name}:",
" @property",
" def model_class(self):",
f" return {model_class}",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
"",
" @property",
" def pretrained_model_kwargs(self):",
' return {"subfolder": "transformer"}',
"",
" @property",
" def generator(self):",
' return torch.Generator("cpu").manual_seed(0)',
"",
" def get_init_dict(self) -> dict[str, int | list[int]]:",
]
if init_params:
lines.append(" # __init__ parameters:")
for param in init_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" return {}",
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
]
)
if forward_params:
lines.append(" # forward() parameters:")
for param in forward_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" # TODO: Fill in dummy inputs",
" return {}",
"",
" @property",
" def input_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
"",
" @property",
" def output_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
]
)
return "\n".join(lines)
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
tester_short = tester.replace("TesterMixin", "")
class_name = f"Test{model_name}{tester_short}"
lines = [f"class {class_name}({config_class}, {tester}):"]
if tester == "TorchCompileTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
elif tester == "IPAdapterTesterMixin":
lines.extend(
[
" @property",
" def ip_adapter_processor_cls(self):",
" return None # TODO: Set processor class",
"",
" def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
" # TODO: Add IP adapter image embeds to inputs",
" return inputs_dict",
"",
" def create_ip_adapter_state_dict(self, model):",
" # TODO: Create IP adapter state dict",
" return {}",
]
)
elif tester == "SingleFileTesterMixin":
lines.extend(
[
" @property",
" def ckpt_path(self):",
' return "" # TODO: Set checkpoint path',
"",
" @property",
" def alternate_ckpt_paths(self):",
" return []",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
]
)
elif tester == "GGUFTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in [
"BitsAndBytesCompileTesterMixin",
"QuantoCompileTesterMixin",
"TorchAoCompileTesterMixin",
"ModelOptCompileTesterMixin",
]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester == "GGUFCompileTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester in [
"PyramidAttentionBroadcastTesterMixin",
"FirstBlockCacheTesterMixin",
"FasterCacheTesterMixin",
]:
lines.append(" pass")
elif tester == "LoraHotSwappingForModelTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
else:
lines.append(" pass")
return "\n".join(lines)
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
testers = determine_testers(model_info, include_optional, imports)
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
lines = [
"# coding=utf-8",
"# Copyright 2025 HuggingFace Inc.",
"#",
'# Licensed under the Apache License, Version 2.0 (the "License");',
"# you may not use this file except in compliance with the License.",
"# You may obtain a copy of the License at",
"#",
"# http://www.apache.org/licenses/LICENSE-2.0",
"#",
"# Unless required by applicable law or agreed to in writing, software",
'# distributed under the License is distributed on an "AS IS" BASIS,',
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"# See the License for the specific language governing permissions and",
"# limitations under the License.",
"",
"import torch",
"",
f"from diffusers import {model_info['name']}",
"from diffusers.utils.torch_utils import randn_tensor",
"",
"from ...testing_utils import enable_full_determinism, torch_device",
]
if "LoraTesterMixin" in testers:
lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
lines.extend(
[
"from ..testing_utils import (",
*[f" {tester}," for tester in sorted(tester_imports)],
")",
"",
"",
"enable_full_determinism()",
"",
"",
]
)
config_class = f"{model_name}TesterConfig"
lines.append(generate_config_class(model_info, model_name))
lines.append("")
lines.append("")
for tester in testers:
lines.append(generate_test_class(model_name, config_class, tester))
lines.append("")
lines.append("")
if "LoraTesterMixin" in testers:
lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def get_test_output_path(model_filepath: str) -> str:
path = Path(model_filepath)
model_filename = path.stem
if "transformers" in path.parts:
return f"tests/models/transformers/test_models_{model_filename}.py"
elif "unets" in path.parts:
return f"tests/models/unets/test_models_{model_filename}.py"
elif "autoencoders" in path.parts:
return f"tests/models/autoencoders/test_models_{model_filename}.py"
else:
return f"tests/models/test_models_{model_filename}.py"
def main():
parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
parser.add_argument(
"model_filepath",
type=str,
help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
)
parser.add_argument(
"--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
)
parser.add_argument(
"--include",
"-i",
type=str,
nargs="*",
default=[],
choices=[
"bnb",
"quanto",
"torchao",
"gguf",
"modelopt",
"bnb_compile",
"quanto_compile",
"torchao_compile",
"gguf_compile",
"modelopt_compile",
"pab_cache",
"fbc_cache",
"faster_cache",
"single_file",
"ip_adapter",
"all",
],
help="Optional testers to include",
)
parser.add_argument(
"--class-name",
"-c",
type=str,
default=None,
help="Specific model class to generate tests for (default: first model class found)",
)
parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")
args = parser.parse_args()
if not Path(args.model_filepath).exists():
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
sys.exit(1)
model_classes, imports = analyze_model_file(args.model_filepath)
if not model_classes:
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
sys.exit(1)
if args.class_name:
model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
if not model_info:
available = [m["name"] for m in model_classes]
print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
sys.exit(1)
else:
model_info = model_classes[0]
if len(model_classes) > 1:
print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
print("Use --class-name to specify a different class", file=sys.stderr)
include_optional = args.include
if "all" in include_optional:
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
if args.dry_run:
print(generated_code)
else:
output_path = args.output or get_test_output_path(args.model_filepath)
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.write(generated_code)
print(f"Generated test file: {output_path}")
print(f"Model class: {model_info['name']}")
print(f"Detected attributes: {list(model_info['attributes'].keys())}")
if __name__ == "__main__":
main()