Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
8cd0f86ad8 Merge branch 'main' into enable-cp-kernels 2026-02-09 16:28:41 +05:30
Sayak Paul
dab372dd27 Merge branch 'main' into enable-cp-kernels 2026-01-26 21:58:00 +08:00
Sayak Paul
79438572e0 Merge branch 'main' into enable-cp-kernels 2026-01-19 10:28:00 +05:30
sayakpaul
2268583f39 up 2026-01-11 20:05:26 +05:30
Sayak Paul
dfbd4857b2 Merge branch 'main' into enable-cp-kernels 2025-12-17 12:14:40 +08:00
Sayak Paul
9bd83616bf Merge branch 'main' into enable-cp-kernels 2025-12-10 12:33:18 +08:00
sayakpaul
f732ff1144 up 2025-12-09 15:30:33 +05:30
sayakpaul
7a8f85b047 up 2025-12-09 14:59:01 +05:30
sayakpaul
82d20e64a5 up 2025-12-09 14:39:07 +05:30
24 changed files with 572 additions and 3277 deletions

View File

@@ -29,31 +29,8 @@ text_encoder = AutoModel.from_pretrained(
)
```
## Custom models
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.
```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```
The `config.json` includes the `auto_map` field pointing to the custom class.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
Then load it with `trust_remote_code=True`.
```py
import torch
from diffusers import AutoModel
@@ -63,39 +40,7 @@ transformer = AutoModel.from_pretrained(
)
```
For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.
```
transformer/
├── config.json # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
> "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

View File

@@ -1,347 +0,0 @@
# DreamBooth training example for Z-Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).
> [!NOTE]
> **About Z-Image**
>
> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.
> [!NOTE]
> **Memory consumption**
>
> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_z_image.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
## Memory Optimizations
> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
Enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT]
> Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient_accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the amount of backward/update passes and hence also memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 1024, but it's good to keep in mind in case you're training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### Z-Image Training
To perform DreamBooth with LoRA on Z-Image, run:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.
### Training with FP8 Quantization
For reduced memory usage with FP8 training:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-fp8"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
```yaml
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set `--learning_rate=1.0`
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-prodigy"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1.0 \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha:
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
**lora_alpha vs. rank:**
This ratio dictates the LoRA's effective strength:
- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training onto different types of layers and blocks.
To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide:
- For attention only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"`
- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string.
> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Aspect Ratio Bucketing
We've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
```bash
--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
```
### Bilingual Prompts
Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:
```bash
--instance_prompt="一只sks狗的照片"
--validation_prompt="一只sks狗在桶里的照片"
```
> [!TIP]
> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data.
## Inference
Once you have trained a LoRA, you can load it for inference:
```python
import torch
from diffusers import ZImagePipeline
pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load your trained LoRA
pipe.load_lora_weights("path/to/your/trained-z-image-lora")
# Generate an image
image = pipe(
prompt="A photo of sks dog in a bucket",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("output.png")
```
---
Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

File diff suppressed because it is too large Load Diff

View File

@@ -2321,14 +2321,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
if has_lora_down_up:
temp_state_dict = {}
for k, v in original_state_dict.items():
new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
temp_state_dict[new_key] = v
original_state_dict = temp_state_dict
num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
@@ -2345,15 +2337,13 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
if linear1_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
linear1_key
)
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
if linear2_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
@@ -2362,10 +2352,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
if qkv_key not in original_state_dict:
continue
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
@@ -2397,9 +2383,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
@@ -2410,27 +2395,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
extra_mappings = {
"img_in": "x_embedder",
"txt_in": "context_embedder",
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"final_layer.linear": "proj_out",
"final_layer.adaLN_modulation.1": "norm_out.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
}
for org_key, diff_key in extra_mappings.items():
for lora_key in lora_keys:
original_key = f"{org_key}.{lora_key}.weight"
if original_key in original_state_dict:
converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
@@ -2455,22 +2421,18 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
has_lora_unet = any(k.startswith("lora_unet_") or k.startswith("lora_unet__") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet__").removeprefix("lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight
suffix = ""
for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
if key.endswith(sfx):
base = key[: -len(sfx)]
suffix = sfx
break
else:
base = key
if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""
# Protected n-grams that must keep their internal underscores
protected = {
@@ -2481,9 +2443,6 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
("to", "out"),
# feed_forward
("feed", "forward"),
# noise and context refiner
("noise", "refiner"),
("context", "refiner"),
}
prot_by_len = {}
@@ -2508,7 +2467,7 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
i += 1
converted_base = ".".join(merged)
return converted_base + suffix
return converted_base + (("." + suffix) if suffix else "")
state_dict = {convert_key(k): v for k, v in state_dict.items()}

View File

@@ -264,6 +264,10 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -278,7 +282,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -603,22 +611,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1069,6 +1094,231 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1107,6 +1357,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1940,7 +2230,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1958,17 +2248,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -2115,7 +2423,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2130,33 +2438,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2787,7 +3130,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2815,6 +3158,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -219,10 +219,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -230,7 +227,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -214,10 +214,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -225,7 +222,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -502,16 +502,13 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
dim_head: int = 64,
eps: float = 1e-6,
cross_attention_dim_head: Optional[int] = None,
bias: bool = True,
processor=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.cross_attention_dim_head = cross_attention_dim_head
self.cross_attention_head_dim = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.use_bias = bias
self.is_cross_attention = cross_attention_dim_head is not None
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
# NOTE: this is not used in "vanilla" WanAttention
@@ -519,10 +516,10 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
# 2. QKV and Output Projections
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
# 3. QK Norm
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -685,10 +682,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -696,7 +690,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -76,7 +76,6 @@ class WanVACETransformerBlock(nn.Module):
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
processor=WanAttnProcessor(),
is_cross_attention=True,
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -179,7 +178,6 @@ class WanVACETransformer3DModel(
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
@register_to_config
def __init__(

View File

@@ -18,6 +18,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -18,6 +18,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import ftfy
import PIL
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -19,6 +19,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from PIL import Image
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -41,7 +41,7 @@ class GGUFQuantizer(DiffusersQuantizer):
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
self.modules_to_not_convert = quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]

View File

@@ -446,17 +446,16 @@ class ModelTesterMixin:
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self, tmp_path):
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Save the model and reload with float16 dtype
# _keep_in_fp32_modules is only enforced during from_pretrained loading
model.save_pretrained(tmp_path)
model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
@@ -471,7 +470,7 @@ class ModelTesterMixin:
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
@@ -491,6 +490,10 @@ class ModelTesterMixin:
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)

View File

@@ -176,7 +176,15 @@ class QuantizationTesterMixin:
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)
# Get model dtype from first parameter
model_dtype = next(model_quantized.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
@@ -221,8 +229,6 @@ class QuantizationTesterMixin:
init_lora_weights=False,
)
model.add_adapter(lora_config)
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)
inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
@@ -1015,6 +1021,9 @@ class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
"""Test that dequantize() works correctly."""
self._test_dequantize({"compute_dtype": torch.bfloat16})
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@is_quantization
@is_modelopt

View File

@@ -12,57 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanTransformer3DModel
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan22-transformer"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -76,160 +76,16 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Transformer 3D."""
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Transformer 3D."""
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -12,62 +12,76 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import unittest
import torch
from diffusers import WanAnimateTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanAnimateTransformer3DModel
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-animate-transformer"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
@property
def output_shape(self) -> tuple[int, ...]:
# Output has fewer channels than input (4 vs 12)
return (4, 21, 16, 16)
def input_shape(self):
return (12, 1, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (12, 21, 16, 16)
def output_shape(self):
return (4, 1, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
def prepare_init_args_and_inputs_for_common(self):
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
return {
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -91,219 +105,22 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size`
face_width = 16
return {
"hidden_states": randn_tensor(
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states_image": randn_tensor(
(batch_size, clip_seq_len, clip_dim),
generator=self.generator,
device=torch_device,
),
"pose_hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"face_pixel_values": randn_tensor(
(batch_size, 3, inference_segment_length, face_height, face_width),
generator=self.generator,
device=torch_device,
),
}
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""
def test_output(self):
# Override test_output because the transformer output is expected to have less channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Animate Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Animate Transformer 3D."""
# Override test_output because the transformer output is expected to have less channels than the main transformer
# input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Animate Transformer 3D."""
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
def test_torch_compile_recompilation_and_graph_break(self):
# Skip: F.pad with mode="replicate" in WanAnimateFaceEncoder triggers importlib.import_module
# internally, which dynamo doesn't support tracing through.
pytest.skip("F.pad with replicate mode triggers unsupported import in torch.compile")
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -1,271 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers import WanVACETransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanVACETransformer3DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-vace-transformer"
@property
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 16,
"out_channels": 16,
"text_dim": 32,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 4,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
"vace_layers": [0, 2],
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 32
sequence_length = 12
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
vace_in_channels = 48
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"control_hidden_states": randn_tensor(
(batch_size, vace_in_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan VACE Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanVACETransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan VACE Transformer 3D."""
def test_torch_compile_repeated_blocks(self):
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
# so we need recompile_limit=2 instead of the default 1.
import torch._dynamo
import torch._inductor.utils
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=2),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -37,9 +37,6 @@ class ModularPipelineTesterMixin:
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(["generator"])
# Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines)
# Subclasses can override this to change the expected output type
output_name = "images"
def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -166,7 +163,7 @@ class ModularPipelineTesterMixin:
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output=self.output_name)
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
@@ -200,16 +197,12 @@ class ModularPipelineTesterMixin:
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output=self.output_name)
output_batch = pipe(**batched_inputs, output=self.output_name)
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
# For batch comparison, we only need to compare the first item
if output_batch.shape[0] == batch_size and output.shape[0] == 1:
output_batch = output_batch[0:1]
max_diff = torch.abs(output_batch - output).max()
max_diff = torch.abs(output_batch[0] - output[0]).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@require_accelerator
@@ -224,32 +217,19 @@ class ModularPipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output=self.output_name)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name)
output = output.cpu()
output_fp16 = output_fp16.cpu()
output_tensor = output.float().cpu()
output_fp16_tensor = output_fp16.float().cpu()
# Check for NaNs in outputs (can happen with tiny models in FP16)
if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any():
pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models")
max_diff = numpy_cosine_similarity_distance(
output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy()
)
# Check if cosine similarity is NaN (which can happen if vectors are zero or very small)
if torch.isnan(torch.tensor(max_diff)):
pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison")
assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})"
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
@@ -271,16 +251,14 @@ class ModularPipelineTesterMixin:
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline().to("cpu")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
output = pipe(**self.get_dummy_inputs(), output="images")
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline().to(torch_device)
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
output = pipe(**self.get_dummy_inputs(), output="images")
assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
@@ -300,7 +278,7 @@ class ModularPipelineTesterMixin:
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name)
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@@ -315,7 +293,8 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in [base_pipe, offload_pipe]:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output=self.output_name)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -336,7 +315,8 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output=self.output_name)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -351,13 +331,13 @@ class ModularGuiderTesterMixin:
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output=self.output_name)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output=self.output_name)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()

View File

@@ -1,49 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from diffusers.modular_pipelines import WanBlocks, WanModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestWanModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = WanModularPipeline
pipeline_blocks_class = WanBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
output_name = "videos"
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 16,
"width": 16,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
pass

View File

@@ -1,44 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestZImageModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = ZImageModularPipeline
pipeline_blocks_class = ZImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-3)

View File

@@ -168,7 +168,7 @@ def assert_tensors_close(
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()