Compare commits

...

15 Commits

Author SHA1 Message Date
Sayak Paul
4749c4f74a Merge branch 'main' into z-image-distillation-lora 2026-02-12 21:21:37 +05:30
dg845
427472eb00 [CI] Fix setuptools pkg_resources Errors (#13129)
Try to fix setuptools pkg_resources issue on CI
2026-02-12 17:48:44 +05:30
dg845
985d83c948 Fix LTX-2 Inference when num_videos_per_prompt > 1 and CFG is Enabled (#13121)
Fix LTX-2 inference when num_videos_per_prompt > 1 and CFG is enabled
2026-02-11 22:35:29 -08:00
Sayak Paul
ed77a246c9 [modular] add tests for robust model loading. (#13120)
* add tests for robust model loading.

* apply review feedback.
2026-02-12 10:04:29 +05:30
Miguel Martin
a1816166a5 Cosmos Transfer2.5 inference pipeline: general/{seg, depth, blur, edge} (#13066)
* initial conversion script

* cosmos control net block

* CosmosAttention

* base model conversion

* wip

* pipeline updates

* convert controlnet

* pipeline: working without controls

* wip

* debugging

* Almost working

* temp

* control working

* cleanup + detail on neg_encoder_hidden_states

* convert edge

* pos emb for control latents

* convert all chkpts

* resolve TODOs

* remove prints

* Docs

* add siglip image reference encoder

* Add unit tests

* controlnet: add duplicate layers

* Additional tests

* skip less

* skip less

* remove image_ref

* minor

* docs

* remove skipped test in transfer

* Don't crash process

* formatting

* revert some changes

* remove skipped test

* make style

* Address comment + fix example

* CosmosAttnProcessor2_0 revert + CosmosAttnProcessor2_5 changes

* make style

* make fix-copies
2026-02-11 18:33:09 -10:00
David El Malih
06a0f98e6e docs: improve docstring scheduling_flow_match_euler_discrete.py (#13127)
Improve docstring scheduling flow match euler discrete
2026-02-11 16:39:55 -08:00
Álvaro Somoza
2b16351270 conversion 2026-02-11 16:05:55 -03:00
Jared Wen
d32483913a [Fix]Allow prompt and prior_token_ids to be provided simultaneously in GlmImagePipeline (#13092)
* allow loose input

Signed-off-by: JaredforReal <w13431838023@gmail.com>

* add tests

Signed-off-by: JaredforReal <w13431838023@gmail.com>

* format test_glm_image

Signed-off-by: JaredforReal <w13431838023@gmail.com>

---------

Signed-off-by: JaredforReal <w13431838023@gmail.com>
2026-02-11 08:29:36 -10:00
David El Malih
64e2adf8f5 docs: improve docstring scheduling_edm_dpmsolver_multistep.py (#13122)
Improve docstring scheduling edm dpmsolver multistep
2026-02-11 08:59:33 -08:00
Dhruv Nair
c3a4cd14b8 [CI] Refactor Wan Model Tests (#13082)
* update

* update

* update

* update

* update

* update

* update

* update
2026-02-11 14:42:58 +05:30
Sayak Paul
4d00980e25 [lora] fix non-diffusers lora key handling for flux2 (#13119)
fix non-diffusers lora key handling for flux2
2026-02-11 08:06:36 +05:30
Álvaro Somoza
5bf248ddd8 [SkyReelsV2] Fix ftfy import (#13113)
fix
2026-02-10 12:56:13 +05:30
Dhruv Nair
bedc67c75f [Docs] Add guide for AutoModel with custom code (#13099)
update
2026-02-10 12:19:44 +05:30
Sayak Paul
20efb79d49 [modular] add modular tests for Z-Image and Wan (#13078)
* add wan modular tests

* style.

* add z-image tests and other fixes.

* style.

* increase tolerance for zimage

* style

* address reviewer feedback.

* address reviewer feedback.

* remove unneeded func

* simplify even more.
2026-02-09 08:27:59 -10:00
Linoy Tsaban
8933686770 Z image lora training (#13056)
* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* fix vae

* fix prompts

* Apply style fixes

* fix license

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-09 15:45:59 +02:00
47 changed files with 5982 additions and 277 deletions

View File

@@ -126,6 +126,11 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

View File

@@ -29,8 +29,31 @@ text_encoder = AutoModel.from_pretrained(
)
```
## Custom models
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.
```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```
The `config.json` includes the `auto_map` field pointing to the custom class.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
Then load it with `trust_remote_code=True`.
```py
import torch
from diffusers import AutoModel
@@ -40,7 +63,39 @@ transformer = AutoModel.from_pretrained(
)
```
For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.
```
transformer/
├── config.json # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
> "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

View File

@@ -0,0 +1,347 @@
# DreamBooth training example for Z-Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).
> [!NOTE]
> **About Z-Image**
>
> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.
> [!NOTE]
> **Memory consumption**
>
> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_z_image.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
## Memory Optimizations
> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
Enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT]
> Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient_accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the amount of backward/update passes and hence also memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 1024, but it's good to keep in mind in case you're training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### Z-Image Training
To perform DreamBooth with LoRA on Z-Image, run:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.
### Training with FP8 Quantization
For reduced memory usage with FP8 training:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-fp8"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
```yaml
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set `--learning_rate=1.0`
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-prodigy"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1.0 \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha:
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
**lora_alpha vs. rank:**
This ratio dictates the LoRA's effective strength:
- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training onto different types of layers and blocks.
To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide:
- For attention only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"`
- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string.
> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Aspect Ratio Bucketing
We've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
```bash
--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
```
### Bilingual Prompts
Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:
```bash
--instance_prompt="一只sks狗的照片"
--validation_prompt="一只sks狗在桶里的照片"
```
> [!TIP]
> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data.
## Inference
Once you have trained a LoRA, you can load it for inference:
```python
import torch
from diffusers import ZImagePipeline
pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load your trained LoRA
pipe.load_lora_weights("path/to/your/trained-z-image-lora")
# Generate an image
image = pipe(
prompt="A photo of sks dog in a bucket",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("output.png")
```
---
Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

File diff suppressed because it is too large Load Diff

View File

@@ -78,12 +78,67 @@ python scripts/convert_cosmos_to_diffusers.py \
--save_pipeline
```
# Cosmos 2.5 Transfer
Download checkpoint
```bash
hf download nvidia/Cosmos-Transfer2.5-2B
```
Convert checkpoint
```bash
# depth
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/depth/626e6618-bfcd-4d9a-a077-1409e2ce353f_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth \
--save_pipeline
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/edge/pipeline \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/edge/models
# blur
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/blur/ba2f44f2-c726-4fe7-949f-597069d9b91c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur \
--save_pipeline
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg \
--save_pipeline
```
"""
import argparse
import pathlib
import sys
from typing import Any, Dict
from typing import Any, Dict, Optional
import torch
from accelerate import init_empty_weights
@@ -95,6 +150,7 @@ from diffusers import (
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosControlNetModel,
CosmosTextToWorldPipeline,
CosmosTransformer3DModel,
CosmosVideoToWorldPipeline,
@@ -103,6 +159,7 @@ from diffusers import (
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -356,8 +413,62 @@ TRANSFORMER_CONFIGS = {
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
"Cosmos-2.5-Transfer-General-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
"controlnet_block_every_n": 7,
"img_context_dim_in": 1152,
"img_context_dim_out": 2048,
"img_context_num_tokens": 256,
},
}
CONTROLNET_CONFIGS = {
"Cosmos-2.5-Transfer-General-2B": {
"n_controlnet_blocks": 4,
"model_channels": 2048,
"in_channels": 130,
"latent_channels": 18, # (16 latent + 1 condition_mask) + 1 padding_mask = 18
"num_attention_heads": 16,
"attention_head_dim": 128,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"patch_size": (1, 2, 2),
"max_size": (128, 240, 240),
"rope_scale": (1.0, 3.0, 3.0),
"extra_pos_embed_type": None,
"img_context_dim_in": 1152,
"img_context_dim_out": 2048,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
CONTROLNET_KEYS_RENAME_DICT = {
**TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0,
"blocks": "blocks",
"control_embedder.proj.1": "patch_embed.proj",
}
CONTROLNET_SPECIAL_KEYS_REMAP = {**TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0}
VAE_KEYS_RENAME_DICT = {
"down.0": "down_blocks.0",
"down.1": "down_blocks.1",
@@ -447,9 +558,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
def convert_transformer(
transformer_type: str,
state_dict: Optional[Dict[str, Any]] = None,
weights_only: bool = True,
):
PREFIX_KEY = "net."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
if "Cosmos-1.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
@@ -467,23 +581,29 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
config = TRANSFORMER_CONFIGS[transformer_type]
transformer = CosmosTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
old2new = {}
new2old = {}
for key in list(state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)
assert new_key not in new2old, f"new key {new_key} already mapped"
assert key not in old2new, f"old key {key} already mapped"
old2new[key] = new_key
new2old[new_key] = key
update_state_dict_(state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for key in list(state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
handler_fn_inplace(key, state_dict)
expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
mapped_keys = set(state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
@@ -497,10 +617,86 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
print(k)
sys.exit(2)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
transformer.load_state_dict(state_dict, strict=True, assign=True)
return transformer
def convert_controlnet(
transformer_type: str,
control_state_dict: Dict[str, Any],
base_state_dict: Dict[str, Any],
weights_only: bool = True,
):
"""
Convert controlnet weights.
Args:
transformer_type: The type of transformer/controlnet
control_state_dict: State dict containing controlnet-specific weights
base_state_dict: State dict containing base transformer weights (for shared modules)
weights_only: Whether to use weights_only loading
"""
if transformer_type not in CONTROLNET_CONFIGS:
raise AssertionError(f"{transformer_type} does not define a ControlNet config")
PREFIX_KEY = "net."
# Process control-specific keys
for key in list(control_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(control_state_dict, key, new_key)
for key in list(control_state_dict.keys()):
for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, control_state_dict)
# Copy shared weights from base transformer to controlnet
# These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj
shared_module_mappings = {
# transformer key prefix -> controlnet key prefix
"patch_embed.": "patch_embed_base.",
"time_embed.": "time_embed.",
"learnable_pos_embed.": "learnable_pos_embed.",
"img_context_proj.": "img_context_proj.",
"crossattn_proj.": "crossattn_proj.",
}
for key in list(base_state_dict.keys()):
for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
if key.startswith(transformer_prefix):
controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
control_state_dict[controlnet_key] = base_state_dict[key].clone()
print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True)
break
cfg = CONTROLNET_CONFIGS[transformer_type]
controlnet = CosmosControlNetModel(**cfg)
expected_keys = set(controlnet.state_dict().keys())
mapped_keys = set(control_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"WARNING: missing controlnet keys ({len(missing_keys)}):", file=sys.stderr, flush=True)
for k in sorted(missing_keys):
print(k, file=sys.stderr)
sys.exit(3)
if unexpected_keys:
print(f"WARNING: unexpected controlnet keys ({len(unexpected_keys)}):", file=sys.stderr, flush=True)
for k in sorted(unexpected_keys):
print(k, file=sys.stderr)
sys.exit(4)
controlnet.load_state_dict(control_state_dict, strict=True, assign=True)
return controlnet
def convert_vae(vae_type: str):
model_name = VAE_CONFIGS[vae_type]["name"]
snapshot_directory = snapshot_download(model_name, repo_type="model")
@@ -586,7 +782,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5(args, transformer, vae):
def save_pipeline_cosmos2_5_predict(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
@@ -614,6 +810,35 @@ def save_pipeline_cosmos2_5(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5_transfer(args, transformer, controlnet, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)
pipe = Cosmos2_5_TransferPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -642,18 +867,61 @@ if __name__ == "__main__":
args = get_args()
transformer = None
controlnet = None
dtype = DTYPE_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
raw_state_dict = None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
raw_state_dict = get_state_dict(
torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only)
)
if raw_state_dict is not None:
if "Transfer" in args.transformer_type:
base_state_dict = {}
control_state_dict = {}
for k, v in raw_state_dict.items():
plain_key = k.removeprefix("net.") if k.startswith("net.") else k
if "control" in plain_key.lower():
control_state_dict[k] = v
else:
base_state_dict[k] = v
assert len(base_state_dict.keys() & control_state_dict.keys()) == 0
# Convert transformer first to get the processed base state dict
transformer = convert_transformer(
args.transformer_type, state_dict=base_state_dict, weights_only=weights_only
)
transformer = transformer.to(dtype=dtype)
# Get converted transformer state dict to copy shared weights to controlnet
converted_base_state_dict = transformer.state_dict()
# Convert controlnet with both control-specific and shared weights from transformer
controlnet = convert_controlnet(
args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only
)
controlnet = controlnet.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(
pathlib.Path(args.output_path) / "transformer", safe_serialization=True, max_shard_size="5GB"
)
controlnet.save_pretrained(
pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB"
)
else:
transformer = convert_transformer(
args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only
)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
@@ -667,6 +935,8 @@ if __name__ == "__main__":
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
else:
vae = None
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
@@ -678,6 +948,11 @@ if __name__ == "__main__":
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
if "Predict" in args.transformer_type:
save_pipeline_cosmos2_5_predict(args, transformer, vae)
elif "Transfer" in args.transformer_type:
save_pipeline_cosmos2_5_transfer(args, transformer, None, vae)
else:
raise AssertionError(f"{args.transformer_type} not supported")
else:
raise AssertionError(f"{args.transformer_type} not supported")

View File

@@ -221,6 +221,7 @@ else:
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
"CosmosControlNetModel",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
@@ -485,6 +486,7 @@ else:
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2_5_TransferPipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -992,6 +994,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
@@ -1226,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2_5_TransferPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,

View File

@@ -2321,6 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
if has_lora_down_up:
temp_state_dict = {}
for k, v in original_state_dict.items():
new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
temp_state_dict[new_key] = v
original_state_dict = temp_state_dict
num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
@@ -2337,13 +2345,15 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
if linear1_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
linear1_key
)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
if linear2_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
@@ -2352,6 +2362,10 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
if qkv_key not in original_state_dict:
continue
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
@@ -2383,8 +2397,9 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
@@ -2395,8 +2410,27 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
extra_mappings = {
"img_in": "x_embedder",
"txt_in": "context_embedder",
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"final_layer.linear": "proj_out",
"final_layer.adaLN_modulation.1": "norm_out.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
}
for org_key, diff_key in extra_mappings.items():
for lora_key in lora_keys:
original_key = f"{org_key}.{lora_key}.weight"
if original_key in original_state_dict:
converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
@@ -2421,18 +2455,22 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") or k.startswith("lora_unet__") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.removeprefix("lora_unet__").removeprefix("lora_unet_"): v for k, v in state_dict.items()}
def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight
if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""
suffix = ""
for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
if key.endswith(sfx):
base = key[: -len(sfx)]
suffix = sfx
break
else:
base = key
# Protected n-grams that must keep their internal underscores
protected = {
@@ -2443,6 +2481,9 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
("to", "out"),
# feed_forward
("feed", "forward"),
# noise and context refiner
("noise", "refiner"),
("context", "refiner"),
}
prot_by_len = {}
@@ -2467,7 +2508,7 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
i += 1
converted_base = ".".join(merged)
return converted_base + (("." + suffix) if suffix else "")
return converted_base + suffix
state_dict = {convert_key(k): v for k, v in state_dict.items()}

View File

@@ -54,6 +54,7 @@ if is_torch_available():
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnets.controlnet_hunyuan"] = [
"HunyuanDiT2DControlNetModel",
@@ -175,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosControlNetModel,
FluxControlNetModel,
FluxMultiControlNetModel,
HunyuanDiT2DControlNetModel,

View File

@@ -3,6 +3,7 @@ from ...utils import is_flax_available, is_torch_available
if is_torch_available():
from .controlnet import ControlNetModel, ControlNetOutput
from .controlnet_cosmos import CosmosControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
from .controlnet_hunyuan import (
HunyuanControlNetOutput,

View File

@@ -0,0 +1,312 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
CosmosEmbedding,
CosmosLearnablePositionalEmbed,
CosmosPatchEmbed,
CosmosRotaryPosEmbed,
CosmosTransformerBlock,
)
if is_torchvision_available():
from torchvision import transforms
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class CosmosControlNetOutput(BaseOutput):
"""
Output of [`CosmosControlNetModel`].
Args:
control_block_samples (`list[torch.Tensor]`):
List of control block activations to be injected into transformer blocks.
"""
control_block_samples: List[torch.Tensor]
class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
ControlNet for Cosmos Transfer2.5.
This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
internally from raw inputs.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
n_controlnet_blocks: int = 4,
in_channels: int = 130,
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
model_channels: int = 2048,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
patch_size: Tuple[int, int, int] = (1, 2, 2),
max_size: Tuple[int, int, int] = (128, 240, 240),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
extra_pos_embed_type: Optional[str] = None,
img_context_dim_in: Optional[int] = None,
img_context_dim_out: int = 2048,
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
):
super().__init__()
self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
self.time_embed = CosmosEmbedding(model_channels, model_channels)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=model_channels,
max_size=max_size,
patch_size=patch_size,
)
self.img_context_proj = None
if img_context_dim_in is not None and img_context_dim_in > 0:
self.img_context_proj = nn.Sequential(
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
nn.GELU(),
)
# Cross-attention projection for text embeddings (same as transformer)
self.crossattn_proj = None
if use_crossattn_projection:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
nn.GELU(),
)
# RoPE for both control and base latents
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.control_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
before_proj=(block_idx == 0),
after_proj=True,
)
for block_idx in range(n_controlnet_blocks)
]
)
self.gradient_checkpointing = False
def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
if isinstance(conditioning_scale, list):
scales = conditioning_scale
else:
scales = [conditioning_scale] * len(self.control_blocks)
if len(scales) < len(self.control_blocks):
logger.warning(
"Received %d control scales, but control network defines %d blocks. "
"Scales will be trimmed or repeated to match.",
len(scales),
len(self.control_blocks),
)
scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
return scales
def forward(
self,
controls_latents: torch.Tensor,
latents: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
condition_mask: torch.Tensor,
conditioning_scale: Union[float, List[float]] = 1.0,
padding_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
return_dict: bool = True,
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
"""
Forward pass for the ControlNet.
Args:
controls_latents: Control signal latents [B, C, T, H, W]
latents: Base latents from the noising process [B, C, T, H, W]
timestep: Diffusion timestep tensor
encoder_hidden_states: Tuple of (text_context, img_context) or text_context
condition_mask: Conditioning mask [B, 1, T, H, W]
conditioning_scale: Scale factor(s) for control outputs
padding_mask: Padding mask [B, 1, H, W] or None
attention_mask: Optional attention mask or None
fps: Frames per second for RoPE or None
return_dict: Whether to return a CosmosControlNetOutput or a tuple
Returns:
CosmosControlNetOutput or tuple of control tensors
"""
B, C, T, H, W = controls_latents.shape
# 1. Prepare control latents
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
control_hidden_states = torch.cat(
[
control_hidden_states,
torch.zeros(
(B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
),
],
dim=1,
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
control_hidden_states = torch.cat(
[control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 2. Prepare base latents (same processing as transformer.forward)
base_hidden_states = latents
if condition_mask is not None:
base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)
base_padding_mask = transforms.functional.resize(
padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
[base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)
# 3. Generate positional embeddings (shared for both)
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None
# 4. Patchify control latents
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)
# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w
base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)
# 6. Time embeddings
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
batch_size, _, num_frames, _, _ = latents.shape
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep_flat = timestep.flatten()
temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
)
else:
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
text_context, img_context = encoder_hidden_states
else:
text_context = encoder_hidden_states
img_context = None
# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
text_context = self.crossattn_proj(text_context)
# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
img_context = self.img_context_proj(img_context)
# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
processed_encoder_hidden_states = (text_context, img_context)
else:
processed_encoder_hidden_states = text_context
# 8. Prepare attention mask
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 9. Run control blocks
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_hidden_states, control_proj = self._gradient_checkpointing_func(
block,
control_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
None, # controlnet_residual
base_hidden_states,
block_idx,
)
else:
control_hidden_states, control_proj = block(
hidden_states=control_hidden_states,
encoder_hidden_states=processed_encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
controlnet_residual=None,
latents=base_hidden_states,
block_idx=block_idx,
)
result.append(control_proj * scale)
if not return_dict:
return (result,)
return CosmosControlNetOutput(control_block_samples=result)

View File

@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if attn.cross_attention_dim_head is None:
if not attn.is_cross_attention:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -219,7 +219,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.is_cross_attention = cross_attention_dim_head is not None
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -227,7 +230,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if self.cross_attention_dim_head is None:
if not self.is_cross_attention:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -152,7 +152,7 @@ class CosmosAdaLayerNormZero(nn.Module):
class CosmosAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
@@ -191,7 +191,6 @@ class CosmosAttnProcessor2_0:
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
@@ -200,18 +199,148 @@ class CosmosAttnProcessor2_0:
value = value.repeat_interleave(query_idx // value_idx, dim=3)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
hidden_states = dispatch_attention_fn(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
# 6. Output projection
hidden_states = hidden_states.flatten(2, 3).type_as(query)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosAttnProcessor2_5:
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attention_mask: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
image_rotary_emb=None,
) -> torch.Tensor:
if not isinstance(encoder_hidden_states, tuple):
raise ValueError("Expected encoder_hidden_states as (text_context, img_context) tuple.")
text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None)
text_mask, img_mask = attention_mask if attention_mask else (None, None)
if text_context is None:
text_context = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(text_context)
value = attn.to_v(text_context)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
query = attn.norm_q(query)
key = attn.norm_k(key)
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
if torch.onnx.is_in_onnx_export():
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
else:
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
attn_out = dispatch_attention_fn(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
attn_mask=text_mask,
dropout_p=0.0,
is_causal=False,
)
attn_out = attn_out.flatten(2, 3).type_as(query)
if img_context is not None:
q_img = attn.q_img(hidden_states)
k_img = attn.k_img(img_context)
v_img = attn.v_img(img_context)
batch_size = hidden_states.shape[0]
dim_head = attn.out_dim // attn.heads
q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
q_img = attn.q_img_norm(q_img)
k_img = attn.k_img_norm(k_img)
q_img_idx = q_img.size(3)
k_img_idx = k_img.size(3)
v_img_idx = v_img.size(3)
k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
img_out = dispatch_attention_fn(
q_img.transpose(1, 2),
k_img.transpose(1, 2),
v_img.transpose(1, 2),
attn_mask=img_mask,
dropout_p=0.0,
is_causal=False,
)
img_out = img_out.flatten(2, 3).type_as(q_img)
hidden_states = attn_out + img_out
else:
hidden_states = attn_out
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosAttention(Attention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# add parameters for image q/k/v
inner_dim = self.heads * self.to_q.out_features // self.heads
self.q_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.k_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.v_img = nn.Linear(self.query_dim, inner_dim, bias=False)
self.q_img_norm = RMSNorm(self.to_q.out_features // self.heads, eps=1e-6, elementwise_affine=True)
self.k_img_norm = RMSNorm(self.to_k.out_features // self.heads, eps=1e-6, elementwise_affine=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
return super().forward(
hidden_states=hidden_states,
# NOTE: type-hint in base class can be ignored
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
class CosmosTransformerBlock(nn.Module):
def __init__(
self,
@@ -222,12 +351,16 @@ class CosmosTransformerBlock(nn.Module):
adaln_lora_dim: int = 256,
qk_norm: str = "rms_norm",
out_bias: bool = False,
img_context: bool = False,
before_proj: bool = False,
after_proj: bool = False,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.img_context = img_context
self.attn1 = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
@@ -240,30 +373,58 @@ class CosmosTransformerBlock(nn.Module):
)
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
if img_context:
self.attn2 = CosmosAttention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_5(),
)
else:
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
# NOTE: zero conv for CosmosControlNet
self.before_proj = None
self.after_proj = None
if before_proj:
self.before_proj = nn.Linear(hidden_size, hidden_size)
if after_proj:
self.after_proj = nn.Linear(hidden_size, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states: Union[
Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]
],
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
extra_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
controlnet_residual: Optional[torch.Tensor] = None,
latents: Optional[torch.Tensor] = None,
block_idx: Optional[int] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.before_proj is not None:
hidden_states = self.before_proj(hidden_states) + latents
if extra_pos_emb is not None:
hidden_states = hidden_states + extra_pos_emb
@@ -284,6 +445,16 @@ class CosmosTransformerBlock(nn.Module):
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate * ff_output
if controlnet_residual is not None:
assert self.after_proj is None
# NOTE: this is assumed to be scaled by the controlnet
hidden_states += controlnet_residual
if self.after_proj is not None:
assert controlnet_residual is None
hs_proj = self.after_proj(hidden_states)
return hidden_states, hs_proj
return hidden_states
@@ -416,6 +587,17 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Whether to concatenate the padding mask to the input latent tensors.
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
controlnet_block_every_n (`int`, *optional*):
Interval between transformer blocks that should receive control residuals (for example, `7` to inject after
every seventh block). Required for Cosmos Transfer2.5.
img_context_dim_in (`int`, *optional*):
The dimension of the input image context feature vector, i.e. it is the D in [B, N, D].
img_context_num_tokens (`int`):
The number of tokens in the image context feature vector, i.e. it is the N in [B, N, D]. If
`img_context_dim_in` is not provided, then this parameter is ignored.
img_context_dim_out (`int`):
The output dimension of the image context projection layer. If `img_context_dim_in` is not provided, then
this parameter is ignored.
"""
_supports_gradient_checkpointing = True
@@ -442,6 +624,10 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
use_crossattn_projection: bool = False,
crossattn_proj_in_channels: int = 1024,
encoder_hidden_states_channels: int = 1024,
controlnet_block_every_n: Optional[int] = None,
img_context_dim_in: Optional[int] = None,
img_context_num_tokens: int = 256,
img_context_dim_out: int = 2048,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
@@ -477,6 +663,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
img_context=self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0,
)
for _ in range(num_layers)
]
@@ -496,17 +683,24 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.gradient_checkpointing = False
if self.config.img_context_dim_in:
self.img_context_proj = nn.Sequential(
nn.Linear(self.config.img_context_dim_in, self.config.img_context_dim_out, bias=True),
nn.GELU(),
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
condition_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> torch.Tensor:
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
# 1. Concatenate padding mask if needed & prepare attention mask
@@ -514,11 +708,11 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
if self.config.concat_padding_mask:
padding_mask = transforms.functional.resize(
padding_mask_resized = transforms.functional.resize(
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
hidden_states = torch.cat(
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
[hidden_states, padding_mask_resized.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
)
if attention_mask is not None:
@@ -554,36 +748,59 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for x in (temb, embedded_timestep)
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
else:
assert False
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
# 5. Process encoder hidden states
text_context, img_context = (
encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
)
if self.config.use_crossattn_projection:
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
text_context = self.crossattn_proj(text_context)
# 5. Transformer blocks
for block in self.transformer_blocks:
if img_context is not None and self.config.img_context_dim_in:
img_context = self.img_context_proj(img_context)
processed_encoder_hidden_states = (
(text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context
)
# 6. Build controlnet block index map
controlnet_block_index_map = {}
if block_controlnet_hidden_states is not None:
n_blocks = len(self.transformer_blocks)
controlnet_block_index_map = {
block_idx: block_controlnet_hidden_states[idx]
for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))
}
# 7. Transformer blocks
for block_idx, block in enumerate(self.transformer_blocks):
controlnet_residual = controlnet_block_index_map.get(block_idx)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
controlnet_residual,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
hidden_states,
processed_encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
controlnet_residual,
)
# 6. Output norm & projection & unpatchify
# 8. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))

View File

@@ -56,10 +56,8 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
x_dtype = x.dtype
needs_reshape = False
if x.ndim != 4 and cos.ndim == 4:
# cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
# The cos/sin batch dim may only be broadcastable, so take batch size from x
b = x.shape[0]
_, h, t, _ = cos.shape
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
b, h, t, _ = cos.shape
x = x.reshape(b, t, h, -1).swapaxes(1, 2)
needs_reshape = True

View File

@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if attn.cross_attention_dim_head is None:
if not attn.is_cross_attention:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -214,7 +214,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.is_cross_attention = cross_attention_dim_head is not None
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -222,7 +225,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if self.cross_attention_dim_head is None:
if not self.is_cross_attention:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if attn.cross_attention_dim_head is None:
if not attn.is_cross_attention:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -502,13 +502,16 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
dim_head: int = 64,
eps: float = 1e-6,
cross_attention_dim_head: Optional[int] = None,
bias: bool = True,
processor=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.cross_attention_head_dim = cross_attention_dim_head
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.use_bias = bias
self.is_cross_attention = cross_attention_dim_head is not None
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
# NOTE: this is not used in "vanilla" WanAttention
@@ -516,10 +519,10 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
# 2. QKV and Output Projections
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
# 3. QK Norm
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -682,7 +685,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.is_cross_attention = cross_attention_dim_head is not None
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -690,7 +696,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if self.cross_attention_dim_head is None:
if not self.is_cross_attention:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape

View File

@@ -76,6 +76,7 @@ class WanVACETransformerBlock(nn.Module):
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
processor=WanAttnProcessor(),
is_cross_attention=True,
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -178,6 +179,7 @@ class WanVACETransformer3DModel(
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
@register_to_config
def __init__(

View File

@@ -167,6 +167,7 @@ else:
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2_5_PredictBasePipeline",
"Cosmos2_5_TransferPipeline",
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
@@ -631,6 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .cosmos import (
Cosmos2_5_PredictBasePipeline,
Cosmos2_5_TransferPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,

View File

@@ -25,6 +25,9 @@ else:
_import_structure["pipeline_cosmos2_5_predict"] = [
"Cosmos2_5_PredictBasePipeline",
]
_import_structure["pipeline_cosmos2_5_transfer"] = [
"Cosmos2_5_TransferPipeline",
]
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
@@ -41,6 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_cosmos2_5_predict import (
Cosmos2_5_PredictBasePipeline,
)
from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline

View File

@@ -0,0 +1,923 @@
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import cv2
>>> import numpy as np
>>> import torch
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )
>>> pipe = pipe.to("cuda")
>>> # Video2World with edge control: Generate video guided by edge maps extracted from input video.
>>> prompt = (
... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It"
... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige"
... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are"
... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is"
... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the"
... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper"
... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist."
... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm"
... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in"
... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision,"
... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing"
... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed"
... "movements and coordination involved in the task."
... )
>>> negative_prompt = (
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
... "Overall, the video is of poor quality."
... )
>>> input_video = load_video(
... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4"
... )
>>> num_frames = 93
>>> # Extract edge maps from the input video using Canny edge detection
>>> edge_maps = [
... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200)
... for frame in input_video[:num_frames]
... ]
>>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W)
>>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W)
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
... negative_prompt=negative_prompt,
... num_frames=num_frames,
... ).frames[0]
>>> export_to_video(video, "edge_controlled_video.mp4", fps=30)
```
"""
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5 base model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5
VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
tokenizer (`AutoTokenizer`):
Tokenizer associated with the Qwen2.5 VL encoder.
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker", "controlnet"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: AutoTokenizer,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
if getattr(self.vae.config, "latents_mean", None) is not None
else None
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
if getattr(self.vae.config, "latents_std", None) is not None
else None
)
self.latents_mean = latents_mean
self.latents_std = latents_std
if self.latents_mean is None or self.latents_std is None:
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
def _get_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
input_ids_batch = []
for sample_idx in range(len(prompt)):
conversations = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are a helpful assistant who will provide prompts to an image generator.",
}
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt[sample_idx],
}
],
},
]
input_ids = self.tokenizer.apply_chat_template(
conversations,
tokenize=True,
add_generation_prompt=False,
add_vision_id=False,
max_length=max_sequence_length,
truncation=True,
padding="max_length",
)
input_ids = torch.LongTensor(input_ids)
input_ids_batch.append(input_ids)
input_ids_batch = torch.stack(input_ids_batch, dim=0)
outputs = self.text_encoder(
input_ids_batch.to(device),
output_hidden_states=True,
)
hidden_states = outputs.hidden_states
normalized_hidden_states = []
for layer_idx in range(1, len(hidden_states)):
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
)
normalized_hidden_states.append(normalized_state)
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
# diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents
def prepare_latents(
self,
video: Optional[torch.Tensor],
batch_size: int,
num_channels_latents: int = 16,
height: int = 704,
width: int = 1280,
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
B = batch_size
C = num_channels_latents
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
H = height // self.vae_scale_factor_spatial
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
cond_latents = torch.zeros_like(latents)
return (
latents,
cond_latents,
cond_mask,
cond_indicator,
)
else:
if video is None:
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
cond_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
for i in range(batch_size)
]
else:
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
latents,
cond_latents,
cond_mask,
cond_indicator,
)
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: Optional[int] = None,
num_frames: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: Union[float, List[float]] = 1.0,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
):
r"""
The call function to the pipeline for generation. Supports three modes:
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
above in "*2Image mode").
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Args:
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `35`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = image or video[0] if image or video else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
# Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
dtype=transformer_dtype,
device=device,
generator=generator,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
condition_mask=cond_mask,
conditioning_scale=controls_conditioning_scale,
padding_mask=padding_mask,
return_dict=False,
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
condition_mask=cond_mask,
conditioning_scale=controls_conditioning_scale,
padding_mask=padding_mask,
return_dict=False,
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
assert self.safety_checker is not None
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
if vid is None:
video_batch.append(np.zeros_like(video[0]))
else:
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video

View File

@@ -658,12 +658,7 @@ class GlmImagePipeline(DiffusionPipeline):
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and prior_token_ids is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prior_token_ids is None:
if prompt is None and prior_token_ids is None:
raise ValueError(
"Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
)
@@ -694,8 +689,8 @@ class GlmImagePipeline(DiffusionPipeline):
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
)
if prior_token_ids is not None and prompt_embeds is None:
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
if prior_token_ids is not None and prompt_embeds is None and prompt is None:
raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
@property
def guidance_scale(self):

View File

@@ -1081,6 +1081,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
# Duplicate the positional ids as well if using CFG
if self.do_classifier_free_guidance:
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:

View File

@@ -1139,6 +1139,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
audio_latents.shape[0], audio_num_frames, audio_latents.device
)
# Duplicate the positional ids as well if using CFG
if self.do_classifier_free_guidance:
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:

View File

@@ -18,7 +18,6 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -18,7 +18,6 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import ftfy
import PIL
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -19,7 +19,6 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from PIL import Image
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -41,7 +41,7 @@ class GGUFQuantizer(DiffusersQuantizer):
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized
self.modules_to_not_convert = quantization_config.modules_to_not_convert
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]

View File

@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -94,19 +96,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
rho: float = 7.0,
solver_order: int = 2,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", # "zero", "sigma_min"
):
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
@@ -145,19 +147,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> float:
# standard deviation of the initial noise distribution
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
def step_index(self) -> int:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> int:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
@@ -274,7 +276,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Optional[Union[str, torch.device]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -460,13 +466,12 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
return alpha_t, sigma_t
def convert_model_output(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
@@ -497,7 +502,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
@@ -508,6 +513,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.
Returns:
`torch.Tensor`:
@@ -538,7 +545,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
@@ -549,6 +556,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor to add to the original samples.
Returns:
`torch.Tensor`:
@@ -609,7 +618,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
sample: torch.Tensor,
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
@@ -698,7 +707,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
@@ -719,7 +728,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -860,5 +869,5 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -102,12 +102,21 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
time_shift_type: Literal["exponential", "linear"] = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -166,6 +175,13 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def set_shift(self, shift: float):
"""
Sets the shift value for the scheduler.
Args:
shift (`float`):
The shift value to be set.
"""
self._shift = shift
def scale_noise(
@@ -218,10 +234,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return sample
def _sigma_to_t(self, sigma):
def _sigma_to_t(self, sigma) -> float:
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
"""
Apply time shifting to the sigmas.
Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.
Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
@@ -302,7 +333,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas = timesteps / self.config.num_train_timesteps
else:
@@ -350,7 +383,24 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self,
timestep: Union[float, torch.FloatTensor],
schedule_timesteps: Optional[torch.FloatTensor] = None,
) -> int:
"""
Get the index for the given timestep.
Args:
timestep (`float` or `torch.FloatTensor`):
The timestep to find the index for.
schedule_timesteps (`torch.FloatTensor`, *optional*):
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.
Returns:
`int`:
The index of the timestep.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -364,7 +414,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return indices[pos].item()
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -405,7 +455,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -474,7 +524,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -595,11 +645,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
return sigmas
def _time_shift_exponential(self, mu, sigma, t):
def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def _time_shift_linear(self, mu, sigma, t):
def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -482,6 +482,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
"""
Apply time shifting to the sigmas.
Args:
mu (`float`):
The mu parameter for the time shift.
sigma (`float`):
The sigma parameter for the time shift.
t (`torch.Tensor`):
The input timesteps.
Returns:
`torch.Tensor`:
The time-shifted timesteps.
"""
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":

View File

@@ -896,6 +896,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CosmosControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CosmosTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -977,6 +977,21 @@ class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2_5_TransferPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

View File

@@ -0,0 +1,255 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CosmosControlNetModel
from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetOutput
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosControlNetModel
main_input_name = "controls_latents"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 1
height = 16
width = 16
text_embed_dim = 32
sequence_length = 12
img_context_dim_in = 32
img_context_num_tokens = 4
# Raw latents (not patchified) - the controlnet computes embeddings internally
controls_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.tensor([0.5]).to(torch_device) # Diffusion timestep
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
# Text embeddings
text_context = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
# Image context for Cosmos 2.5
img_context = torch.randn((batch_size, img_context_num_tokens, img_context_dim_in)).to(torch_device)
encoder_hidden_states = (text_context, img_context)
return {
"controls_latents": controls_latents,
"latents": latents,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"condition_mask": condition_mask,
"conditioning_scale": 1.0,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (16, 1, 16, 16)
@property
def output_shape(self):
# Output is tuple of n_controlnet_blocks tensors, each with shape (batch, num_patches, model_channels)
# After stacking by normalize_output: (n_blocks, batch, num_patches, model_channels)
# For test config: n_blocks=2, num_patches=64 (1*8*8), model_channels=32
# output_shape is used as (batch_size,) + output_shape, so: (2, 64, 32)
return (2, 64, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"n_controlnet_blocks": 2,
"in_channels": 16 + 1 + 1, # control_latent_channels + condition_mask + padding_mask
"latent_channels": 16 + 1 + 1, # base_latent_channels (16) + condition_mask (1) + padding_mask (1) = 18
"model_channels": 32,
"num_attention_heads": 2,
"attention_head_dim": 16,
"mlp_ratio": 2,
"text_embed_dim": 32,
"adaln_lora_dim": 4,
"patch_size": (1, 2, 2),
"max_size": (4, 32, 32),
"rope_scale": (2.0, 1.0, 1.0),
"extra_pos_embed_type": None,
"img_context_dim_in": 32,
"img_context_dim_out": 32,
"use_crossattn_projection": False, # Test doesn't need this projection
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output_format(self):
"""Test that the model outputs CosmosControlNetOutput with correct structure."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertIsInstance(output.control_block_samples, list)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
for tensor in output.control_block_samples:
self.assertIsInstance(tensor, torch.Tensor)
def test_output_list_format(self):
"""Test that return_dict=False returns a tuple containing a list."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict, return_dict=False)
self.assertIsInstance(output, tuple)
self.assertEqual(len(output), 1)
self.assertIsInstance(output[0], list)
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
def test_conditioning_scale_single(self):
"""Test that a single conditioning scale is broadcast to all blocks."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_dict["conditioning_scale"] = 0.5
with torch.no_grad():
output = model(**inputs_dict)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_conditioning_scale_list(self):
"""Test that a list of conditioning scales is applied per block."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# Provide a scale for each block
inputs_dict["conditioning_scale"] = [0.5, 1.0]
with torch.no_grad():
output = model(**inputs_dict)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_forward_with_none_img_context(self):
"""Test forward pass when img_context is None."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# Set encoder_hidden_states to (text_context, None)
text_context = inputs_dict["encoder_hidden_states"][0]
inputs_dict["encoder_hidden_states"] = (text_context, None)
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_forward_without_img_context_proj(self):
"""Test forward pass when img_context_proj is not configured."""
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Disable img_context_proj
init_dict["img_context_dim_in"] = None
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# When img_context is disabled, pass only text context (not a tuple)
text_context = inputs_dict["encoder_hidden_states"][0]
inputs_dict["encoder_hidden_states"] = text_context
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsInstance(output, CosmosControlNetOutput)
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosControlNetModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Note: test_set_attn_processor_for_determinism already handles uses_custom_attn_processor=True
# so no explicit skip needed for it
# Note: test_forward_signature and test_set_default_attn_processor don't exist in base class
# Skip tests that don't apply to this architecture
@unittest.skip("CosmosControlNetModel doesn't use norm groups.")
def test_forward_with_norm_groups(self):
pass
# Skip tests that expect .sample attribute - ControlNets don't have this
@unittest.skip("ControlNet output doesn't have .sample attribute")
def test_effective_gradient_checkpointing(self):
pass
# Skip tests that compute MSE loss against single tensor output
@unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
def test_ema_training(self):
pass
@unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
def test_training(self):
pass
# Skip tests where output shape comparison doesn't apply to ControlNets
@unittest.skip("ControlNet output shape doesn't match input shape by design")
def test_output(self):
pass
# Skip outputs_equivalence - dict/list comparison logic not compatible (recursive_check expects dict.values())
@unittest.skip("ControlNet output structure not compatible with recursive dict check")
def test_outputs_equivalence(self):
pass
# Skip model parallelism - base test uses torch.allclose(base_output[0], new_output[0]) which fails
# because output[0] is the list of control_block_samples, not a tensor
@unittest.skip("test_model_parallelism uses torch.allclose on output[0] which is a list, not a tensor")
def test_model_parallelism(self):
pass
# Skip layerwise casting tests - these have two issues:
# 1. _inference and _memory: dtype compatibility issues with learnable_pos_embed and float8/bfloat16
# 2. _training: same as test_training - mse_loss expects tensor, not list
@unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
def test_layerwise_casting_inference(self):
pass
@unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
def test_layerwise_casting_memory(self):
pass
@unittest.skip("test_layerwise_casting_training computes mse_loss on list output")
def test_layerwise_casting_training(self):
pass

View File

@@ -446,16 +446,17 @@ class ModelTesterMixin:
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
def test_keep_in_fp32_modules(self, tmp_path):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
# Save the model and reload with float16 dtype
# _keep_in_fp32_modules is only enforced during from_pretrained loading
model.save_pretrained(tmp_path)
model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
@@ -470,7 +471,7 @@ class ModelTesterMixin:
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
@@ -490,10 +491,6 @@ class ModelTesterMixin:
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)

View File

@@ -176,15 +176,7 @@ class QuantizationTesterMixin:
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)
# Get model dtype from first parameter
model_dtype = next(model_quantized.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
@@ -229,6 +221,8 @@ class QuantizationTesterMixin:
init_lora_weights=False,
)
model.add_adapter(lora_config)
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)
inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
@@ -1021,9 +1015,6 @@ class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
"""Test that dequantize() works correctly."""
self._test_dequantize({"compute_dtype": torch.bfloat16})
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@is_quantization
@is_modelopt

View File

@@ -12,57 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan22-transformer"
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -76,16 +76,160 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Transformer 3D."""
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Transformer 3D."""
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -12,76 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import WanAnimateTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanAnimateTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-animate-transformer"
@property
def input_shape(self):
return (12, 1, 16, 16)
def output_shape(self) -> tuple[int, ...]:
# Output has fewer channels than input (4 vs 12)
return (4, 21, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def input_shape(self) -> tuple[int, ...]:
return (12, 21, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
init_dict = {
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -105,22 +91,219 @@ class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size`
face_width = 16
return {
"hidden_states": randn_tensor(
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states_image": randn_tensor(
(batch_size, clip_seq_len, clip_dim),
generator=self.generator,
device=torch_device,
),
"pose_hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"face_pixel_values": randn_tensor(
(batch_size, 3, inference_segment_length, face_height, face_width),
generator=self.generator,
device=torch_device,
),
}
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""
def test_output(self):
# Override test_output because the transformer output is expected to have less channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Animate Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Override test_output because the transformer output is expected to have less channels than the main transformer
# input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Animate Transformer 3D."""
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Animate Transformer 3D."""
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
def test_torch_compile_recompilation_and_graph_break(self):
# Skip: F.pad with mode="replicate" in WanAnimateFaceEncoder triggers importlib.import_module
# internally, which dynamo doesn't support tracing through.
pytest.skip("F.pad with replicate mode triggers unsupported import in torch.compile")
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -0,0 +1,271 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers import WanVACETransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanVACETransformer3DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-vace-transformer"
@property
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 16,
"out_channels": 16,
"text_dim": 32,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 4,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
"vace_layers": [0, 2],
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 32
sequence_length = 12
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
vace_in_channels = 48
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"control_hidden_states": randn_tensor(
(batch_size, vace_in_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan VACE Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanVACETransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan VACE Transformer 3D."""
def test_torch_compile_repeated_blocks(self):
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
# so we need recompile_limit=2 instead of the default 1.
import torch._dynamo
import torch._inductor.utils
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=2),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -6,7 +6,7 @@ import pytest
import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
@@ -37,6 +37,9 @@ class ModularPipelineTesterMixin:
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(["generator"])
# Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines)
# Subclasses can override this to change the expected output type
output_name = "images"
def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -163,7 +166,7 @@ class ModularPipelineTesterMixin:
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
output = pipe(**batched_input, output=self.output_name)
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
@@ -197,12 +200,16 @@ class ModularPipelineTesterMixin:
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
output = pipe(**inputs, output=self.output_name)
output_batch = pipe(**batched_inputs, output=self.output_name)
assert output_batch.shape[0] == batch_size
max_diff = torch.abs(output_batch[0] - output[0]).max()
# For batch comparison, we only need to compare the first item
if output_batch.shape[0] == batch_size and output.shape[0] == 1:
output_batch = output_batch[0:1]
max_diff = torch.abs(output_batch - output).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@require_accelerator
@@ -217,19 +224,32 @@ class ModularPipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
output = pipe(**inputs, output=self.output_name)
fp16_inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
output = output.cpu()
output_fp16 = output_fp16.cpu()
output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name)
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
output_tensor = output.float().cpu()
output_fp16_tensor = output_fp16.float().cpu()
# Check for NaNs in outputs (can happen with tiny models in FP16)
if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any():
pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models")
max_diff = numpy_cosine_similarity_distance(
output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy()
)
# Check if cosine similarity is NaN (which can happen if vectors are zero or very small)
if torch.isnan(torch.tensor(max_diff)):
pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison")
assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})"
@require_accelerator
def test_to_device(self):
@@ -251,14 +271,16 @@ class ModularPipelineTesterMixin:
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline().to("cpu")
output = pipe(**self.get_dummy_inputs(), output="images")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline().to(torch_device)
output = pipe(**self.get_dummy_inputs(), output="images")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
@@ -278,7 +300,7 @@ class ModularPipelineTesterMixin:
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name)
assert images.shape[0] == batch_size * num_images_per_prompt
@@ -293,8 +315,7 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in [base_pipe, offload_pipe]:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image = pipe(**inputs, output=self.output_name)
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -315,8 +336,7 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image = pipe(**inputs, output=self.output_name)
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -331,13 +351,13 @@ class ModularGuiderTesterMixin:
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output="images")
out_no_cfg = pipe(**inputs, output=self.output_name)
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output="images")
out_cfg = pipe(**inputs, output=self.output_name)
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()
@@ -578,3 +598,68 @@ class TestModularModelCardContent:
content = generate_modular_model_card_content(blocks)
assert "5-block architecture" in content["model_description"]
class TestAutoModelLoadIdTagging:
def test_automodel_tags_load_id(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
assert hasattr(model, "_diffusers_load_id"), "Model should have _diffusers_load_id attribute"
assert model._diffusers_load_id != "null", "_diffusers_load_id should not be 'null'"
# Verify load_id contains the expected fields
load_id = model._diffusers_load_id
assert "hf-internal-testing/tiny-stable-diffusion-xl-pipe" in load_id
assert "unet" in load_id
def test_automodel_update_components(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
auto_model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
pipe.update_components(unet=auto_model)
assert pipe.unet is auto_model
assert "unet" in pipe._component_specs
spec = pipe._component_specs["unet"]
assert spec.pretrained_model_name_or_path == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
assert spec.subfolder == "unet"
class TestLoadComponentsSkipBehavior:
def test_load_components_skips_already_loaded(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
original_unet = pipe.unet
pipe.load_components()
# Verify that the unet is the same object (not reloaded)
assert pipe.unet is original_unet, "load_components should skip already loaded components"
def test_load_components_selective_loading(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(names="unet", torch_dtype=torch.float32)
# Verify only requested component was loaded.
assert hasattr(pipe, "unet")
assert pipe.unet is not None
assert getattr(pipe, "vae", None) is None
def test_load_components_skips_invalid_pretrained_path(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe._component_specs["test_component"] = ComponentSpec(
name="test_component",
type_hint=torch.nn.Module,
pretrained_model_name_or_path=None,
default_creation_method="from_pretrained",
)
pipe.load_components(torch_dtype=torch.float32)
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None

View File

View File

@@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from diffusers.modular_pipelines import WanBlocks, WanModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestWanModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = WanModularPipeline
pipeline_blocks_class = WanBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
output_name = "videos"
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 16,
"width": 16,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
pass

View File

@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestZImageModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = ZImageModularPipeline
pipeline_blocks_class = ZImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-3)

View File

@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLWan,
Cosmos2_5_TransferPipeline,
CosmosControlNetModel,
CosmosTransformer3DModel,
UniPCMultistepScheduler,
)
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
safety_checker = DummyCosmosSafetyChecker()
device_map = kwargs.get("device_map", "cpu")
torch_dtype = kwargs.get("torch_dtype")
if device_map is not None or torch_dtype is not None:
safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
kwargs["safety_checker"] = safety_checker
return Cosmos2_5_TransferPipeline.from_pretrained(*args, **kwargs)
class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2_5_TransferWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
# Transformer with img_context support for Transfer2.5
transformer = CosmosTransformer3DModel(
in_channels=16 + 1,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
controlnet_block_every_n=1,
img_context_dim_in=32,
img_context_num_tokens=4,
img_context_dim_out=32,
)
torch.manual_seed(0)
controlnet = CosmosControlNetModel(
n_controlnet_blocks=2,
in_channels=16 + 1 + 1, # control latent channels + condition_mask + padding_mask
latent_channels=16 + 1 + 1, # base latent channels (16) + condition_mask (1) + padding_mask (1) = 18
model_channels=32,
num_attention_heads=2,
attention_head_dim=16,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
patch_size=(1, 2, 2),
max_size=(4, 32, 32),
rope_scale=(2.0, 1.0, 1.0),
extra_pos_embed_type="learnable", # Match transformer's config
img_context_dim_in=32,
img_context_dim_out=32,
use_crossattn_projection=False, # Test doesn't need this projection
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = UniPCMultistepScheduler()
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
components = {
"transformer": transformer,
"controlnet": controlnet,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"num_frames": 3,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_inference_with_controls(self):
"""Test inference with control inputs (ControlNet)."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
# Add control video input - should be a video tensor
inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width
inputs["controls_conditioning_scale"] = 1.0
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (3, 3, 32, 32))
self.assertTrue(torch.isfinite(generated_video).all())
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
_ = pipe(**inputs)[0]
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not getattr(self, "test_attention_slicing", True):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
# Remove components that aren't saved as standard diffusers models
if "safety_checker" in model_components:
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
# Skip components that are not loaded from disk or have special handling
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass

View File

@@ -281,6 +281,86 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Should return 4 images (2 prompts × 2 images per prompt)
self.assertEqual(len(images), 4)
def test_prompt_with_prior_token_ids(self):
"""Test that prompt and prior_token_ids can be provided together.
When both are given, the AR generation step is skipped (prior_token_ids is used
directly) and prompt is used to generate prompt_embeds via the glyph encoder.
"""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
height, width = 32, 32
# Step 1: Run with prompt only to get prior_token_ids from AR model
generator = torch.Generator(device=device).manual_seed(0)
prior_token_ids, _, _ = pipe.generate_prior_tokens(
prompt="A photo of a cat",
height=height,
width=width,
device=torch.device(device),
generator=torch.Generator(device=device).manual_seed(0),
)
# Step 2: Run with both prompt and prior_token_ids — should not raise
generator = torch.Generator(device=device).manual_seed(0)
inputs_both = {
"prompt": "A photo of a cat",
"prior_token_ids": prior_token_ids,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.5,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}
images = pipe(**inputs_both).images
self.assertEqual(len(images), 1)
self.assertEqual(images[0].shape, (3, 32, 32))
def test_check_inputs_rejects_invalid_combinations(self):
"""Test that check_inputs correctly rejects invalid input combinations."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
height, width = 32, 32
# Neither prompt nor prior_token_ids → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)
# prior_token_ids alone without prompt or prompt_embeds → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt=None,
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prior_token_ids=torch.randint(0, 100, (1, 64)),
)
# prompt + prompt_embeds together → error
with self.assertRaises(ValueError):
pipe.check_inputs(
prompt="A cat",
height=height,
width=width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=torch.randn(1, 16, 32),
)
@unittest.skip("Needs to be revisited.")
def test_encode_prompt_works_in_isolation(self):
pass

View File

@@ -2406,7 +2406,11 @@ class PipelineTesterMixin:
if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
# `component.device` prints the `onload_device` type. We should probably override the
# `device` property in `ModelMixin`.
component_device = next(component.parameters())[0].device
# Skip modules with no parameters (e.g., dummy safety checkers with only buffers)
params = list(component.parameters())
if not params:
continue
component_device = params[0].device
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
@require_torch_accelerator

View File

@@ -168,7 +168,7 @@ def assert_tensors_close(
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()