mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-14 06:45:22 +08:00)

Compare commits: modular-do...fix-i2v-lt (19 commits)

Commits in this comparison: c458048d09, 4527dcfad3, 76af013a41, 277e305589, 5f3ea22513, 427472eb00, 985d83c948, ed77a246c9, a1816166a5, 06a0f98e6e, d32483913a, 64e2adf8f5, c3a4cd14b8, 4d00980e25, 5bf248ddd8, bedc67c75f, 20efb79d49, 8933686770, baaa8d040b

.github/workflows/pr_tests_gpu.yml (vendored, 5 lines changed)

@@ -199,6 +199,11 @@ jobs:

      - name: Install dependencies
        run: |
          # Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
          uv pip install pip==25.2 setuptools==80.10.2
          uv pip install --no-build-isolation k-diffusion==0.0.12
          uv pip install --upgrade pip setuptools
          # Install the rest as normal
          uv pip install -e ".[quality]"
          uv pip install peft@git+https://github.com/huggingface/peft.git
          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

.github/workflows/push_tests.yml (vendored, 5 lines changed)

@@ -126,6 +126,11 @@ jobs:

      - name: Install dependencies
        run: |
          # Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
          uv pip install pip==25.2 setuptools==80.10.2
          uv pip install --no-build-isolation k-diffusion==0.0.12
          uv pip install --upgrade pip setuptools
          # Install the rest as normal
          uv pip install -e ".[quality]"
          uv pip install peft@git+https://github.com/huggingface/peft.git
          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

@@ -106,8 +106,6 @@ video, audio = pipe(
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
    output_type="np",
    return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

encode_video(
    video[0],

@@ -29,8 +29,31 @@ text_encoder = AutoModel.from_pretrained(
)
```

## Custom models

[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.

A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.

```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```

The `config.json` includes the `auto_map` field pointing to the custom class.

```json
{
  "auto_map": {
    "AutoModel": "my_model.MyCustomModel"
  }
}
```
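
For reference, the `my_model.py` module could look something like the minimal sketch below. This is a hypothetical illustration rather than the layout of any particular repository; the only requirements are that the file defines the class named in `auto_map` and that the class can load the accompanying weights.

```py
# my_model.py - hypothetical minimal custom model module referenced by auto_map.
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class MyCustomModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        # Any PyTorch modules work here; this tiny projection is just a placeholder.
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)
```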

Then load it with `trust_remote_code=True`.

```py
import torch
from diffusers import AutoModel

@@ -40,7 +63,39 @@ transformer = AutoModel.from_pretrained(
)
```

For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.

```
transformer/
├── config.json # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```

If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).

> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
>     "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```

> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

examples/dreambooth/README_z_image.md (new file, 347 lines)

@@ -0,0 +1,347 @@

# DreamBooth training example for Z-Image

[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that achieves performance comparable to full fine-tuning with only a fraction of trainable parameters.

The `train_dreambooth_lora_z_image.py` script shows how to implement the [LoRA](https://huggingface.co/blog/lora) training procedure and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).

> [!NOTE]
> **About Z-Image**
>
> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.

> [!NOTE]
> **Memory consumption**
>
> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then navigate to the `examples/dreambooth` folder and run
```bash
pip install -r requirements_z_image.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, enabling torch compile mode can yield dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure `peft>=0.6.0` is installed in your environment.

### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

We will also push the trained LoRA parameters to the Hugging Face Hub, so make sure you are logged in (`huggingface-cli login`).

## Memory Optimizations

> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption. However, some techniques may be mutually exclusive, so be sure to check before launching a training run.

### CPU Offloading
To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.

### Latent Caching
Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable latent caching, simply pass `--cache_latents`.
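
Conceptually, caching latents boils down to the sketch below. This is a simplified illustration of the idea, not the training script's exact code; `vae` and `train_dataloader` stand in for the objects the script constructs.

```python
import torch


def cache_latents(vae, train_dataloader, weight_dtype=torch.bfloat16):
    """Encode every training image once so the VAE can be freed afterwards."""
    latents_cache = []
    with torch.no_grad():
        for batch in train_dataloader:  # assumed to yield {"pixel_values": tensor} batches
            pixel_values = batch["pixel_values"].to(vae.device, dtype=weight_dtype)
            latents_cache.append(vae.encode(pixel_values).latent_dist)
    return latents_cache


# After caching, the VAE is no longer needed on the GPU during the training loop:
# del vae; torch.cuda.empty_cache()
```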

### QLoRA: Low Precision Training with Quantization
Perform low-precision training using FP8, 8-bit, or 4-bit quantization to reduce memory usage. You can use the following flags:

- **FP8 training** with `torchao`:
  Enable FP8 training by passing `--do_fp8_training`.
  > [!IMPORTANT]
  > Since FP8 training utilizes FP8 tensor cores, it requires CUDA GPUs with a compute capability of at least 8.9 (you can check yours with the snippet right after this list). If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.

- **NF4 training** with `bitsandbytes`:
  Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path`, for example to enable 4-bit NF4 quantization.
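
To check whether your GPU meets the FP8 requirement mentioned above, you can query its compute capability:

```python
import torch

# FP8 tensor cores require compute capability 8.9 or higher (e.g. Ada Lovelace or Hopper GPUs).
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability {major}.{minor}; FP8 training supported: {(major, minor) >= (8, 9)}")
```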

### Gradient Checkpointing and Accumulation
* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.

### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.

### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all images in the train/validation dataset are resized to this resolution.
By default, images are resized to a resolution of 1024; keep this flag in mind if you plan to train at higher resolutions.

### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision (dtype) in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
This significantly reduces memory requirements without a noticeable quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.

## Training Examples

### Z-Image Training

To perform DreamBooth with LoRA on Z-Image, run:

```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora"

accelerate launch train_dreambooth_lora_z_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --cache_latents \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=5.0 \
  --use_8bit_adam \
  --gradient_accumulation_steps=4 \
  --optimizer="adamW" \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This lets us qualitatively check whether the training is progressing as expected.

> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.

### Training with FP8 Quantization

For reduced memory usage with FP8 training:

```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-fp8"

accelerate launch train_dreambooth_lora_z_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --do_fp8_training \
  --gradient_checkpointing \
  --cache_latents \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=5.0 \
  --use_8bit_adam \
  --gradient_accumulation_steps=4 \
  --optimizer="adamW" \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

### FSDP on the transformer

By setting the Accelerate configuration to FSDP, the transformer blocks will be wrapped automatically. For example, set the configuration to:

```yaml
distributed_type: FSDP
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_sharding_strategy: HYBRID_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock
  fsdp_forward_prefetch: true
  fsdp_sync_module_states: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_use_orig_params: false
  fsdp_activation_checkpointing: true
  fsdp_reshard_after_forward: true
  fsdp_cpu_ram_efficient_loading: false
```

### Prodigy Optimizer

Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).

To use Prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
```bash
--optimizer="prodigy"
```

> [!TIP]
> When using Prodigy it's generally good practice to set `--learning_rate=1.0`.

```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-prodigy"

accelerate launch train_dreambooth_lora_z_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --cache_latents \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=5.0 \
  --gradient_accumulation_steps=4 \
  --optimizer="prodigy" \
  --learning_rate=1.0 \
  --report_to="wandb" \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

### LoRA Rank and Alpha

Two key LoRA hyperparameters are LoRA rank and LoRA alpha:

- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.

**lora_alpha vs. rank:**

This ratio dictates the LoRA's effective strength:
- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
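
As a concrete reference, here is how rank and alpha would appear in a PEFT `LoraConfig`. Treat this as an illustration of the scaling relationship only; the exact config the training script builds (target modules, init scheme, and so on) may differ.

```python
from peft import LoraConfig

# The effective LoRA update is scaled by lora_alpha / r:
#   r=16, lora_alpha=16 -> scale 1.0 (learned strength)
#   r=16, lora_alpha=32 -> scale 2.0 (amplified)
#   r=16, lora_alpha=8  -> scale 0.5 (dampened)
transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
```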

### Target Modules

When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training to different types of layers and blocks.

To allow more flexibility and control over the targeted modules, we added `--lora_layers`, which lets you specify the exact modules for LoRA training as a comma-separated string. Here are some examples of target modules you can provide:

- For attention-only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"`
- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"`

> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string.

> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.

### Aspect Ratio Bucketing

We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.

To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of `height,width` pairs, such as:

```bash
--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
```

### Bilingual Prompts

Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:

```bash
--instance_prompt="一只sks狗的照片"
--validation_prompt="一只sks狗在桶里的照片"
```

> [!TIP]
> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data.

## Inference

Once you have trained a LoRA, you can load it for inference:

```python
import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Load your trained LoRA
pipe.load_lora_weights("path/to/your/trained-z-image-lora")

# Generate an image
image = pipe(
    prompt="A photo of sks dog in a bucket",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]

image.save("output.png")
```
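
If `ZImagePipeline` exposes the standard Diffusers LoRA helper methods (an assumption worth verifying against the pipeline class), you can also merge the adapter into the base weights or remove it again:

```python
# Assumes the standard Diffusers LoRA helpers are available on this pipeline.
pipe.fuse_lora(lora_scale=1.0)  # merge the LoRA into the base weights
# ... run inference ...
pipe.unfuse_lora()              # restore the original weights
pipe.unload_lora_weights()      # drop the adapter entirely
```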

---

Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

examples/dreambooth/train_dreambooth_lora_z_image.py (new file, 1912 lines)
File diff suppressed because it is too large
@@ -78,12 +78,67 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
# Cosmos 2.5 Transfer
|
||||
|
||||
Download checkpoint
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Transfer2.5-2B
|
||||
```
|
||||
|
||||
Convert checkpoint
|
||||
```bash
|
||||
# depth
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/depth/626e6618-bfcd-4d9a-a077-1409e2ce353f_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/depth \
|
||||
--save_pipeline
|
||||
|
||||
# edge
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/edge/pipeline \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/edge/models
|
||||
|
||||
# blur
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/blur/ba2f44f2-c726-4fe7-949f-597069d9b91c_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/blur \
|
||||
--save_pipeline
|
||||
|
||||
# seg
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/seg \
|
||||
--save_pipeline
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
@@ -95,6 +150,7 @@ from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosControlNetModel,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosTransformer3DModel,
|
||||
CosmosVideoToWorldPipeline,
|
||||
@@ -103,6 +159,7 @@ from diffusers import (
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
|
||||
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -356,8 +413,62 @@ TRANSFORMER_CONFIGS = {
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
"Cosmos-2.5-Transfer-General-2B": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
"controlnet_block_every_n": 7,
|
||||
"img_context_dim_in": 1152,
|
||||
"img_context_dim_out": 2048,
|
||||
"img_context_num_tokens": 256,
|
||||
},
|
||||
}
|
||||
|
||||
CONTROLNET_CONFIGS = {
|
||||
"Cosmos-2.5-Transfer-General-2B": {
|
||||
"n_controlnet_blocks": 4,
|
||||
"model_channels": 2048,
|
||||
"in_channels": 130,
|
||||
"latent_channels": 18, # (16 latent + 1 condition_mask) + 1 padding_mask = 18
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"patch_size": (1, 2, 2),
|
||||
"max_size": (128, 240, 240),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"extra_pos_embed_type": None,
|
||||
"img_context_dim_in": 1152,
|
||||
"img_context_dim_out": 2048,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
}
|
||||
|
||||
CONTROLNET_KEYS_RENAME_DICT = {
|
||||
**TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0,
|
||||
"blocks": "blocks",
|
||||
"control_embedder.proj.1": "patch_embed.proj",
|
||||
}
|
||||
|
||||
|
||||
CONTROLNET_SPECIAL_KEYS_REMAP = {**TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
"down.0": "down_blocks.0",
|
||||
"down.1": "down_blocks.1",
|
||||
@@ -447,9 +558,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
|
||||
def convert_transformer(
|
||||
transformer_type: str,
|
||||
state_dict: Optional[Dict[str, Any]] = None,
|
||||
weights_only: bool = True,
|
||||
):
|
||||
PREFIX_KEY = "net."
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
|
||||
|
||||
if "Cosmos-1.0" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
||||
@@ -467,23 +581,29 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
config = TRANSFORMER_CONFIGS[transformer_type]
|
||||
transformer = CosmosTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
old2new = {}
|
||||
new2old = {}
|
||||
for key in list(state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
print(key, "->", new_key, flush=True)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
assert new_key not in new2old, f"new key {new_key} already mapped"
|
||||
assert key not in old2new, f"old key {key} already mapped"
|
||||
old2new[key] = new_key
|
||||
new2old[new_key] = key
|
||||
update_state_dict_(state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for key in list(state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
handler_fn_inplace(key, state_dict)
|
||||
|
||||
expected_keys = set(transformer.state_dict().keys())
|
||||
mapped_keys = set(original_state_dict.keys())
|
||||
mapped_keys = set(state_dict.keys())
|
||||
missing_keys = expected_keys - mapped_keys
|
||||
unexpected_keys = mapped_keys - expected_keys
|
||||
if missing_keys:
|
||||
@@ -497,10 +617,86 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
print(k)
|
||||
sys.exit(2)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_controlnet(
|
||||
transformer_type: str,
|
||||
control_state_dict: Dict[str, Any],
|
||||
base_state_dict: Dict[str, Any],
|
||||
weights_only: bool = True,
|
||||
):
|
||||
"""
|
||||
Convert controlnet weights.
|
||||
|
||||
Args:
|
||||
transformer_type: The type of transformer/controlnet
|
||||
control_state_dict: State dict containing controlnet-specific weights
|
||||
base_state_dict: State dict containing base transformer weights (for shared modules)
|
||||
weights_only: Whether to use weights_only loading
|
||||
"""
|
||||
if transformer_type not in CONTROLNET_CONFIGS:
|
||||
raise AssertionError(f"{transformer_type} does not define a ControlNet config")
|
||||
|
||||
PREFIX_KEY = "net."
|
||||
|
||||
# Process control-specific keys
|
||||
for key in list(control_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in CONTROLNET_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(control_state_dict, key, new_key)
|
||||
|
||||
for key in list(control_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in CONTROLNET_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, control_state_dict)
|
||||
|
||||
# Copy shared weights from base transformer to controlnet
|
||||
# These are the duplicated modules: patch_embed_base, time_embed, learnable_pos_embed, img_context_proj, crossattn_proj
|
||||
shared_module_mappings = {
|
||||
# transformer key prefix -> controlnet key prefix
|
||||
"patch_embed.": "patch_embed_base.",
|
||||
"time_embed.": "time_embed.",
|
||||
"learnable_pos_embed.": "learnable_pos_embed.",
|
||||
"img_context_proj.": "img_context_proj.",
|
||||
"crossattn_proj.": "crossattn_proj.",
|
||||
}
|
||||
|
||||
for key in list(base_state_dict.keys()):
|
||||
for transformer_prefix, controlnet_prefix in shared_module_mappings.items():
|
||||
if key.startswith(transformer_prefix):
|
||||
controlnet_key = controlnet_prefix + key[len(transformer_prefix) :]
|
||||
control_state_dict[controlnet_key] = base_state_dict[key].clone()
|
||||
print(f"Copied shared weight: {key} -> {controlnet_key}", flush=True)
|
||||
break
|
||||
|
||||
cfg = CONTROLNET_CONFIGS[transformer_type]
|
||||
controlnet = CosmosControlNetModel(**cfg)
|
||||
|
||||
expected_keys = set(controlnet.state_dict().keys())
|
||||
mapped_keys = set(control_state_dict.keys())
|
||||
missing_keys = expected_keys - mapped_keys
|
||||
unexpected_keys = mapped_keys - expected_keys
|
||||
if missing_keys:
|
||||
print(f"WARNING: missing controlnet keys ({len(missing_keys)}):", file=sys.stderr, flush=True)
|
||||
for k in sorted(missing_keys):
|
||||
print(k, file=sys.stderr)
|
||||
sys.exit(3)
|
||||
if unexpected_keys:
|
||||
print(f"WARNING: unexpected controlnet keys ({len(unexpected_keys)}):", file=sys.stderr, flush=True)
|
||||
for k in sorted(unexpected_keys):
|
||||
print(k, file=sys.stderr)
|
||||
sys.exit(4)
|
||||
|
||||
controlnet.load_state_dict(control_state_dict, strict=True, assign=True)
|
||||
return controlnet
|
||||
|
||||
|
||||
def convert_vae(vae_type: str):
|
||||
model_name = VAE_CONFIGS[vae_type]["name"]
|
||||
snapshot_directory = snapshot_download(model_name, repo_type="model")
|
||||
@@ -586,7 +782,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def save_pipeline_cosmos2_5(args, transformer, vae):
|
||||
def save_pipeline_cosmos2_5_predict(args, transformer, vae):
|
||||
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
|
||||
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
@@ -614,6 +810,35 @@ def save_pipeline_cosmos2_5(args, transformer, vae):
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def save_pipeline_cosmos2_5_transfer(args, transformer, controlnet, vae):
|
||||
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
|
||||
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
text_encoder_path, torch_dtype="auto", device_map="cpu"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
use_karras_sigmas=True,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
sigma_max=200.0,
|
||||
sigma_min=0.01,
|
||||
)
|
||||
|
||||
pipe = Cosmos2_5_TransferPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
controlnet=controlnet,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
safety_checker=lambda *args, **kwargs: None,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
@@ -642,18 +867,61 @@ if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
controlnet = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None
|
||||
assert args.vae_type is not None
|
||||
|
||||
raw_state_dict = None
|
||||
if args.transformer_ckpt_path is not None:
|
||||
weights_only = "Cosmos-1.0" in args.transformer_type
|
||||
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
raw_state_dict = get_state_dict(
|
||||
torch.load(args.transformer_ckpt_path, map_location="cpu", weights_only=weights_only)
|
||||
)
|
||||
|
||||
if raw_state_dict is not None:
|
||||
if "Transfer" in args.transformer_type:
|
||||
base_state_dict = {}
|
||||
control_state_dict = {}
|
||||
for k, v in raw_state_dict.items():
|
||||
plain_key = k.removeprefix("net.") if k.startswith("net.") else k
|
||||
if "control" in plain_key.lower():
|
||||
control_state_dict[k] = v
|
||||
else:
|
||||
base_state_dict[k] = v
|
||||
assert len(base_state_dict.keys() & control_state_dict.keys()) == 0
|
||||
|
||||
# Convert transformer first to get the processed base state dict
|
||||
transformer = convert_transformer(
|
||||
args.transformer_type, state_dict=base_state_dict, weights_only=weights_only
|
||||
)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
|
||||
# Get converted transformer state dict to copy shared weights to controlnet
|
||||
converted_base_state_dict = transformer.state_dict()
|
||||
|
||||
# Convert controlnet with both control-specific and shared weights from transformer
|
||||
controlnet = convert_controlnet(
|
||||
args.transformer_type, control_state_dict, converted_base_state_dict, weights_only=weights_only
|
||||
)
|
||||
controlnet = controlnet.to(dtype=dtype)
|
||||
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
pathlib.Path(args.output_path) / "transformer", safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
controlnet.save_pretrained(
|
||||
pathlib.Path(args.output_path) / "controlnet", safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
transformer = convert_transformer(
|
||||
args.transformer_type, state_dict=raw_state_dict, weights_only=weights_only
|
||||
)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_type is not None:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
@@ -667,6 +935,8 @@ if __name__ == "__main__":
|
||||
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
else:
|
||||
vae = None
|
||||
|
||||
if args.save_pipeline:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
@@ -678,6 +948,11 @@ if __name__ == "__main__":
|
||||
assert args.tokenizer_path is not None
|
||||
save_pipeline_cosmos_2_0(args, transformer, vae)
|
||||
elif "Cosmos-2.5" in args.transformer_type:
|
||||
save_pipeline_cosmos2_5(args, transformer, vae)
|
||||
if "Predict" in args.transformer_type:
|
||||
save_pipeline_cosmos2_5_predict(args, transformer, vae)
|
||||
elif "Transfer" in args.transformer_type:
|
||||
save_pipeline_cosmos2_5_transfer(args, transformer, None, vae)
|
||||
else:
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
else:
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
|
||||
@@ -221,6 +221,7 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"CosmosControlNetModel",
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
@@ -485,6 +486,7 @@ else:
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2_5_TransferPipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"Cosmos2VideoToWorldPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
@@ -992,6 +994,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
CosmosControlNetModel,
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
@@ -1226,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2_5_TransferPipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
|
||||
@@ -2321,6 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
|
||||
if has_lora_down_up:
|
||||
temp_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
|
||||
temp_state_dict[new_key] = v
|
||||
original_state_dict = temp_state_dict
|
||||
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in original_state_dict.keys():
|
||||
@@ -2337,13 +2345,15 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
attn_prefix = f"single_transformer_blocks.{sl}.attn"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear1.{lora_key}.weight"
|
||||
)
|
||||
linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
|
||||
if linear1_key in original_state_dict:
|
||||
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
linear1_key
|
||||
)
|
||||
|
||||
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear2.{lora_key}.weight"
|
||||
)
|
||||
linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
|
||||
if linear2_key in original_state_dict:
|
||||
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
|
||||
|
||||
for dl in range(num_double_layers):
|
||||
transformer_block_prefix = f"transformer_blocks.{dl}"
|
||||
@@ -2352,6 +2362,10 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
for attn_type in attn_types:
|
||||
attn_prefix = f"{transformer_block_prefix}.attn"
|
||||
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
|
||||
|
||||
if qkv_key not in original_state_dict:
|
||||
continue
|
||||
|
||||
fused_qkv_weight = original_state_dict.pop(qkv_key)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
@@ -2383,8 +2397,9 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
for org_proj, diff_proj in proj_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
if original_key in original_state_dict:
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
mlp_mappings = [
|
||||
("img_mlp.0", "ff.linear_in"),
|
||||
@@ -2395,8 +2410,27 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
for org_mlp, diff_mlp in mlp_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
if original_key in original_state_dict:
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
extra_mappings = {
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
}
|
||||
|
||||
for org_key, diff_key in extra_mappings.items():
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"{org_key}.{lora_key}.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
@@ -54,6 +54,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["cache_utils"] = ["CacheMixin"]
|
||||
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_hunyuan"] = [
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
@@ -175,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
CosmosControlNetModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
|
||||
@@ -3,6 +3,7 @@ from ...utils import is_flax_available, is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
from .controlnet import ControlNetModel, ControlNetOutput
|
||||
from .controlnet_cosmos import CosmosControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
from .controlnet_hunyuan import (
|
||||
HunyuanControlNetOutput,
|
||||
|
||||
312
src/diffusers/models/controlnets/controlnet_cosmos.py
Normal file
312
src/diffusers/models/controlnets/controlnet_cosmos.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput, is_torchvision_available, logging
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_cosmos import (
|
||||
CosmosEmbedding,
|
||||
CosmosLearnablePositionalEmbed,
|
||||
CosmosPatchEmbed,
|
||||
CosmosRotaryPosEmbed,
|
||||
CosmosTransformerBlock,
|
||||
)
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosmosControlNetOutput(BaseOutput):
|
||||
"""
|
||||
Output of [`CosmosControlNetModel`].
|
||||
|
||||
Args:
|
||||
control_block_samples (`list[torch.Tensor]`):
|
||||
List of control block activations to be injected into transformer blocks.
|
||||
"""
|
||||
|
||||
control_block_samples: List[torch.Tensor]
|
||||
|
||||
|
||||
class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
ControlNet for Cosmos Transfer2.5.
|
||||
|
||||
This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
|
||||
learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
|
||||
internally from raw inputs.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
|
||||
_no_split_modules = ["CosmosTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["learnable_pos_embed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
n_controlnet_blocks: int = 4,
|
||||
in_channels: int = 130,
|
||||
latent_channels: int = 18, # base latent channels (latents + condition_mask) + padding_mask
|
||||
model_channels: int = 2048,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
mlp_ratio: float = 4.0,
|
||||
text_embed_dim: int = 1024,
|
||||
adaln_lora_dim: int = 256,
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
max_size: Tuple[int, int, int] = (128, 240, 240),
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
extra_pos_embed_type: Optional[str] = None,
|
||||
img_context_dim_in: Optional[int] = None,
|
||||
img_context_dim_out: int = 2048,
|
||||
use_crossattn_projection: bool = False,
|
||||
crossattn_proj_in_channels: int = 1024,
|
||||
encoder_hidden_states_channels: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)
|
||||
|
||||
self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
|
||||
self.time_embed = CosmosEmbedding(model_channels, model_channels)
|
||||
|
||||
self.learnable_pos_embed = None
|
||||
if extra_pos_embed_type == "learnable":
|
||||
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
|
||||
hidden_size=model_channels,
|
||||
max_size=max_size,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.img_context_proj = None
|
||||
if img_context_dim_in is not None and img_context_dim_in > 0:
|
||||
self.img_context_proj = nn.Sequential(
|
||||
nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
# Cross-attention projection for text embeddings (same as transformer)
|
||||
self.crossattn_proj = None
|
||||
if use_crossattn_projection:
|
||||
self.crossattn_proj = nn.Sequential(
|
||||
nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
# RoPE for both control and base latents
|
||||
self.rope = CosmosRotaryPosEmbed(
|
||||
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
|
||||
)
|
||||
|
||||
self.control_blocks = nn.ModuleList(
|
||||
[
|
||||
CosmosTransformerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
cross_attention_dim=text_embed_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
qk_norm="rms_norm",
|
||||
out_bias=False,
|
||||
img_context=img_context_dim_in is not None and img_context_dim_in > 0,
|
||||
before_proj=(block_idx == 0),
|
||||
after_proj=True,
|
||||
)
|
||||
for block_idx in range(n_controlnet_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
|
||||
if isinstance(conditioning_scale, list):
|
||||
scales = conditioning_scale
|
||||
else:
|
||||
scales = [conditioning_scale] * len(self.control_blocks)
|
||||
|
||||
if len(scales) < len(self.control_blocks):
|
||||
logger.warning(
|
||||
"Received %d control scales, but control network defines %d blocks. "
|
||||
"Scales will be trimmed or repeated to match.",
|
||||
len(scales),
|
||||
len(self.control_blocks),
|
||||
)
|
||||
scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
|
||||
return scales
|
||||
|
||||
def forward(
|
||||
self,
|
||||
controls_latents: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
condition_mask: torch.Tensor,
|
||||
conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[int] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
|
||||
"""
|
||||
Forward pass for the ControlNet.
|
||||
|
||||
Args:
|
||||
controls_latents: Control signal latents [B, C, T, H, W]
|
||||
latents: Base latents from the noising process [B, C, T, H, W]
|
||||
timestep: Diffusion timestep tensor
|
||||
encoder_hidden_states: Tuple of (text_context, img_context) or text_context
|
||||
condition_mask: Conditioning mask [B, 1, T, H, W]
|
||||
conditioning_scale: Scale factor(s) for control outputs
|
||||
padding_mask: Padding mask [B, 1, H, W] or None
|
||||
attention_mask: Optional attention mask or None
|
||||
fps: Frames per second for RoPE or None
|
||||
return_dict: Whether to return a CosmosControlNetOutput or a tuple
|
||||
|
||||
Returns:
|
||||
CosmosControlNetOutput or tuple of control tensors
|
||||
"""
|
||||
B, C, T, H, W = controls_latents.shape

# 1. Prepare control latents
control_hidden_states = controls_latents
vace_in_channels = self.config.in_channels - 1
if control_hidden_states.shape[1] < vace_in_channels - 1:
    pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
    control_hidden_states = torch.cat(
        [
            control_hidden_states,
            torch.zeros(
                (B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
            ),
        ],
        dim=1,
    )

control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)

padding_mask_resized = transforms.functional.resize(
    padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
control_hidden_states = torch.cat(
    [control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)

# 2. Prepare base latents (same processing as transformer.forward)
base_hidden_states = latents
if condition_mask is not None:
    base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)

base_padding_mask = transforms.functional.resize(
    padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
base_hidden_states = torch.cat(
    [base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
)

# 3. Generate positional embeddings (shared for both)
image_rotary_emb = self.rope(control_hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None

# 4. Patchify control latents
control_hidden_states = self.patch_embed(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(1, 3)

# 5. Patchify base latents
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = T // p_t
post_patch_height = H // p_h
post_patch_width = W // p_w

base_hidden_states = self.patch_embed_base(base_hidden_states)
base_hidden_states = base_hidden_states.flatten(1, 3)

# 6. Time embeddings
if timestep.ndim == 1:
    temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
elif timestep.ndim == 5:
    batch_size, _, num_frames, _, _ = latents.shape
    assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
        f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
    )
    timestep_flat = timestep.flatten()
    temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
    temb, embedded_timestep = (
        x.view(batch_size, post_patch_num_frames, 1, 1, -1)
        .expand(-1, -1, post_patch_height, post_patch_width, -1)
        .flatten(1, 3)
        for x in (temb, embedded_timestep)
    )
else:
    raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")

# 7. Process encoder hidden states
if isinstance(encoder_hidden_states, tuple):
    text_context, img_context = encoder_hidden_states
else:
    text_context = encoder_hidden_states
    img_context = None

# Apply cross-attention projection to text context
if self.crossattn_proj is not None:
    text_context = self.crossattn_proj(text_context)

# Apply cross-attention projection to image context (if provided)
if img_context is not None and self.img_context_proj is not None:
    img_context = self.img_context_proj(img_context)

# Combine text and image context into a single tuple
if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
    processed_encoder_hidden_states = (text_context, img_context)
else:
    processed_encoder_hidden_states = text_context

# 8. Prepare attention mask
if attention_mask is not None:
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

# 9. Run control blocks
scales = self._expand_conditioning_scale(conditioning_scale)
result = []
for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        control_hidden_states, control_proj = self._gradient_checkpointing_func(
            block,
            control_hidden_states,
            processed_encoder_hidden_states,
            embedded_timestep,
            temb,
            image_rotary_emb,
            extra_pos_emb,
            attention_mask,
            None,  # controlnet_residual
            base_hidden_states,
            block_idx,
        )
    else:
        control_hidden_states, control_proj = block(
            hidden_states=control_hidden_states,
            encoder_hidden_states=processed_encoder_hidden_states,
            embedded_timestep=embedded_timestep,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            extra_pos_emb=extra_pos_emb,
            attention_mask=attention_mask,
            controlnet_residual=None,
            latents=base_hidden_states,
            block_idx=block_idx,
        )
    result.append(control_proj * scale)

if not return_dict:
    return (result,)

return CosmosControlNetOutput(control_block_samples=result)
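A minimal sketch of driving the forward pass documented above, assuming `controlnet` is an already-instantiated `CosmosControlNetModel` whose config matches these illustrative shapes; every tensor below is a placeholder.

import torch

# Illustrative shapes only: B=1, 16 latent channels, 8 latent frames, a 44x80 latent grid.
controls_latents = torch.randn(1, 16, 8, 44, 80)
latents = torch.randn(1, 16, 8, 44, 80)
condition_mask = torch.zeros(1, 1, 8, 44, 80)
padding_mask = torch.zeros(1, 1, 44, 80)      # resized internally to the latent grid
timestep = torch.rand(1, 1, 8, 1, 1)          # per-frame timesteps, [B, 1, T, 1, 1]
text_context = torch.randn(1, 512, 1024)      # [B, N, D] prompt embeddings (placeholder D)

control_block_samples = controlnet(
    controls_latents=controls_latents,
    latents=latents,
    timestep=timestep,
    encoder_hidden_states=text_context,        # or (text_context, img_context)
    condition_mask=condition_mask,
    conditioning_scale=1.0,                    # float, or one scale per control block
    padding_mask=padding_mask,
    return_dict=False,
)[0]
# The returned list of scaled per-block projections is what the transformer consumes as
# `block_controlnet_hidden_states`.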
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
if not attn.is_cross_attention:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -219,7 +219,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -227,7 +230,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if self.cross_attention_dim_head is None:
|
||||
if not self.is_cross_attention:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -12,17 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -152,7 +152,7 @@ class CosmosAdaLayerNormZero(nn.Module):
|
||||
|
||||
class CosmosAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
@@ -191,7 +191,6 @@ class CosmosAttnProcessor2_0:
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
|
||||
else:
|
||||
query_idx = query.size(3)
|
||||
key_idx = key.size(3)
|
||||
@@ -200,18 +199,148 @@ class CosmosAttnProcessor2_0:
|
||||
value = value.repeat_interleave(query_idx // value_idx, dim=3)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query.transpose(1, 2),
|
||||
key.transpose(1, 2),
|
||||
value.transpose(1, 2),
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
|
||||
|
||||
# 6. Output projection
|
||||
hidden_states = hidden_states.flatten(2, 3).type_as(query)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosAttnProcessor2_5:
|
||||
def __init__(self):
|
||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
attention_mask: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
image_rotary_emb=None,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(encoder_hidden_states, tuple):
|
||||
raise ValueError("Expected encoder_hidden_states as (text_context, img_context) tuple.")
|
||||
|
||||
text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None)
|
||||
text_mask, img_mask = attention_mask if attention_mask else (None, None)
|
||||
|
||||
if text_context is None:
|
||||
text_context = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(text_context)
|
||||
value = attn.to_v(text_context)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
else:
|
||||
query_idx = query.size(3)
|
||||
key_idx = key.size(3)
|
||||
value_idx = value.size(3)
|
||||
key = key.repeat_interleave(query_idx // key_idx, dim=3)
|
||||
value = value.repeat_interleave(query_idx // value_idx, dim=3)
|
||||
|
||||
attn_out = dispatch_attention_fn(
|
||||
query.transpose(1, 2),
|
||||
key.transpose(1, 2),
|
||||
value.transpose(1, 2),
|
||||
attn_mask=text_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_out = attn_out.flatten(2, 3).type_as(query)
|
||||
|
||||
if img_context is not None:
|
||||
q_img = attn.q_img(hidden_states)
|
||||
k_img = attn.k_img(img_context)
|
||||
v_img = attn.v_img(img_context)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
dim_head = attn.out_dim // attn.heads
|
||||
|
||||
q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
|
||||
k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
|
||||
v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
|
||||
|
||||
q_img = attn.q_img_norm(q_img)
|
||||
k_img = attn.k_img_norm(k_img)
|
||||
|
||||
q_img_idx = q_img.size(3)
|
||||
k_img_idx = k_img.size(3)
|
||||
v_img_idx = v_img.size(3)
|
||||
k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
|
||||
v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
|
||||
|
||||
img_out = dispatch_attention_fn(
|
||||
q_img.transpose(1, 2),
|
||||
k_img.transpose(1, 2),
|
||||
v_img.transpose(1, 2),
|
||||
attn_mask=img_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
)
|
||||
img_out = img_out.flatten(2, 3).type_as(q_img)
|
||||
hidden_states = attn_out + img_out
|
||||
else:
|
||||
hidden_states = attn_out
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosAttention(Attention):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# add parameters for image q/k/v
|
||||
inner_dim = self.heads * self.to_q.out_features // self.heads
|
||||
self.q_img = nn.Linear(self.query_dim, inner_dim, bias=False)
|
||||
self.k_img = nn.Linear(self.query_dim, inner_dim, bias=False)
|
||||
self.v_img = nn.Linear(self.query_dim, inner_dim, bias=False)
|
||||
self.q_img_norm = RMSNorm(self.to_q.out_features // self.heads, eps=1e-6, elementwise_affine=True)
|
||||
self.k_img_norm = RMSNorm(self.to_k.out_features // self.heads, eps=1e-6, elementwise_affine=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
# NOTE: type-hint in base class can be ignored
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
class CosmosTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -222,12 +351,16 @@ class CosmosTransformerBlock(nn.Module):
|
||||
adaln_lora_dim: int = 256,
|
||||
qk_norm: str = "rms_norm",
|
||||
out_bias: bool = False,
|
||||
img_context: bool = False,
|
||||
before_proj: bool = False,
|
||||
after_proj: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.img_context = img_context
|
||||
self.attn1 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
@@ -240,30 +373,58 @@ class CosmosTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
if img_context:
|
||||
self.attn2 = CosmosAttention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_5(),
|
||||
)
|
||||
else:
|
||||
self.attn2 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
|
||||
|
||||
# NOTE: zero conv for CosmosControlNet
|
||||
self.before_proj = None
|
||||
self.after_proj = None
|
||||
if before_proj:
|
||||
self.before_proj = nn.Linear(hidden_size, hidden_size)
|
||||
if after_proj:
|
||||
self.after_proj = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Union[
|
||||
Optional[torch.Tensor], Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]
|
||||
],
|
||||
embedded_timestep: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
extra_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
controlnet_residual: Optional[torch.Tensor] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
block_idx: Optional[int] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.before_proj is not None:
|
||||
hidden_states = self.before_proj(hidden_states) + latents
|
||||
|
||||
if extra_pos_emb is not None:
|
||||
hidden_states = hidden_states + extra_pos_emb
|
||||
|
||||
@@ -284,6 +445,16 @@ class CosmosTransformerBlock(nn.Module):
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate * ff_output
|
||||
|
||||
if controlnet_residual is not None:
|
||||
assert self.after_proj is None
|
||||
# NOTE: this is assumed to be scaled by the controlnet
|
||||
hidden_states += controlnet_residual
|
||||
|
||||
if self.after_proj is not None:
|
||||
assert controlnet_residual is None
|
||||
hs_proj = self.after_proj(hidden_states)
|
||||
return hidden_states, hs_proj
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -416,6 +587,17 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            Whether to concatenate the padding mask to the input latent tensors.
        extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
            The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
        controlnet_block_every_n (`int`, *optional*):
            Interval between transformer blocks that receive control residuals (for example, `7` adds a control
            residual at every seventh block, starting from the first). Required for Cosmos Transfer2.5.
        img_context_dim_in (`int`, *optional*):
            The dimension of the input image context feature vector, i.e. the D in [B, N, D].
        img_context_num_tokens (`int`):
            The number of tokens in the image context feature vector, i.e. the N in [B, N, D]. If
            `img_context_dim_in` is not provided, this parameter is ignored.
        img_context_dim_out (`int`):
            The output dimension of the image context projection layer. If `img_context_dim_in` is not provided,
            this parameter is ignored.
    """

    _supports_gradient_checkpointing = True
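A hedged sketch of wiring up the new config fields when constructing a small test model; the sizes below are placeholders (not a released checkpoint) and the remaining constructor arguments are assumed to keep their defaults.

from diffusers import CosmosTransformer3DModel

transformer = CosmosTransformer3DModel(
    num_attention_heads=4,
    attention_head_dim=32,
    num_layers=8,
    controlnet_block_every_n=4,   # control residuals are added at blocks 0 and 4
    img_context_dim_in=64,        # enables the image-context cross-attention path
    img_context_num_tokens=16,
    img_context_dim_out=128,
)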
@@ -442,6 +624,10 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
use_crossattn_projection: bool = False,
|
||||
crossattn_proj_in_channels: int = 1024,
|
||||
encoder_hidden_states_channels: int = 1024,
|
||||
controlnet_block_every_n: Optional[int] = None,
|
||||
img_context_dim_in: Optional[int] = None,
|
||||
img_context_num_tokens: int = 256,
|
||||
img_context_dim_out: int = 2048,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
@@ -477,6 +663,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
qk_norm="rms_norm",
|
||||
out_bias=False,
|
||||
img_context=self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -496,17 +683,24 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
if self.config.img_context_dim_in:
|
||||
self.img_context_proj = nn.Sequential(
|
||||
nn.Linear(self.config.img_context_dim_in, self.config.img_context_dim_out, bias=True),
|
||||
nn.GELU(),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
block_controlnet_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[int] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
# 1. Concatenate padding mask if needed & prepare attention mask
|
||||
@@ -514,11 +708,11 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
|
||||
|
||||
if self.config.concat_padding_mask:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask_resized = transforms.functional.resize(
|
||||
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
|
||||
[hidden_states, padding_mask_resized.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -554,36 +748,59 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for x in (temb, embedded_timestep)
|
||||
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
|
||||
else:
|
||||
assert False
|
||||
raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")
|
||||
|
||||
# 5. Process encoder hidden states
|
||||
text_context, img_context = (
|
||||
encoder_hidden_states if isinstance(encoder_hidden_states, tuple) else (encoder_hidden_states, None)
|
||||
)
|
||||
if self.config.use_crossattn_projection:
|
||||
encoder_hidden_states = self.crossattn_proj(encoder_hidden_states)
|
||||
text_context = self.crossattn_proj(text_context)
|
||||
|
||||
# 5. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if img_context is not None and self.config.img_context_dim_in:
|
||||
img_context = self.img_context_proj(img_context)
|
||||
|
||||
processed_encoder_hidden_states = (
|
||||
(text_context, img_context) if isinstance(encoder_hidden_states, tuple) else text_context
|
||||
)
|
||||
|
||||
# 6. Build controlnet block index map
|
||||
controlnet_block_index_map = {}
|
||||
if block_controlnet_hidden_states is not None:
|
||||
n_blocks = len(self.transformer_blocks)
|
||||
controlnet_block_index_map = {
|
||||
block_idx: block_controlnet_hidden_states[idx]
|
||||
for idx, block_idx in list(enumerate(range(0, n_blocks, self.config.controlnet_block_every_n)))
|
||||
}
|
||||
|
||||
# 7. Transformer blocks
|
||||
for block_idx, block in enumerate(self.transformer_blocks):
|
||||
controlnet_residual = controlnet_block_index_map.get(block_idx)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
processed_encoder_hidden_states,
|
||||
embedded_timestep,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
extra_pos_emb,
|
||||
attention_mask,
|
||||
controlnet_residual,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
embedded_timestep=embedded_timestep,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
extra_pos_emb=extra_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
hidden_states,
|
||||
processed_encoder_hidden_states,
|
||||
embedded_timestep,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
extra_pos_emb,
|
||||
attention_mask,
|
||||
controlnet_residual,
|
||||
)
|
||||
|
||||
# 6. Output norm & projection & unpatchify
|
||||
# 8. Output norm & projection & unpatchify
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
|
||||
|
||||
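The index map built in step 6 above pairs each controlnet output with every `controlnet_block_every_n`-th transformer block; a standalone illustration of that indexing with made-up sizes:

n_blocks = 28                 # illustrative number of transformer blocks
controlnet_block_every_n = 7
residuals = [f"residual_{i}" for i in range(4)]   # stand-ins for the controlnet outputs

index_map = {
    block_idx: residuals[idx]
    for idx, block_idx in enumerate(range(0, n_blocks, controlnet_block_every_n))
}
print(index_map)  # {0: 'residual_0', 7: 'residual_1', 14: 'residual_2', 21: 'residual_3'}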
@@ -56,10 +56,8 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Ten
    x_dtype = x.dtype
    needs_reshape = False
    if x.ndim != 4 and cos.ndim == 4:
        # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
        # The cos/sin batch dim may only be broadcastable, so take batch size from x
        b = x.shape[0]
        _, h, t, _ = cos.shape
        # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
        b, h, t, _ = cos.shape
        x = x.reshape(b, t, h, -1).swapaxes(1, 2)
        needs_reshape = True

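A small shape check illustrating why the batch size is now taken from `x` rather than `cos`: the rotary table may carry a broadcastable batch dimension of 1 while `x` has the real batch. All sizes below are made up.

import torch

b, h, t, d = 2, 4, 6, 16
x = torch.randn(b, t, h * d)      # flattened (b, t, h * dim_per_head) input
cos = torch.randn(1, h, t, d)     # broadcastable batch dimension of 1

b_from_x = x.shape[0]
_, h_c, t_c, _ = cos.shape
x_reshaped = x.reshape(b_from_x, t_c, h_c, -1).swapaxes(1, 2)
print(x_reshaped.shape)           # torch.Size([2, 4, 6, 16]); taking b from cos would fold
                                  # the real batch into the per-head dimension instead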
@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
if not attn.is_cross_attention:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -214,7 +214,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -222,7 +225,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if self.cross_attention_dim_head is None:
|
||||
if not self.is_cross_attention:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
if not attn.is_cross_attention:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -502,13 +502,16 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
|
||||
dim_head: int = 64,
|
||||
eps: float = 1e-6,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.cross_attention_head_dim = cross_attention_dim_head
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
|
||||
# NOTE: this is not used in "vanilla" WanAttention
|
||||
@@ -516,10 +519,10 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
|
||||
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
# 2. QKV and Output Projections
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
|
||||
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
|
||||
|
||||
# 3. QK Norm
|
||||
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
|
||||
@@ -682,7 +685,10 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -690,7 +696,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if self.cross_attention_dim_head is None:
|
||||
if not self.is_cross_attention:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -76,6 +76,7 @@ class WanVACETransformerBlock(nn.Module):
|
||||
eps=eps,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
processor=WanAttnProcessor(),
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
@@ -178,6 +179,7 @@ class WanVACETransformer3DModel(
|
||||
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -167,6 +167,7 @@ else:
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2_5_TransferPipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
@@ -631,6 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .cosmos import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2_5_TransferPipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
|
||||
@@ -25,6 +25,9 @@ else:
|
||||
_import_structure["pipeline_cosmos2_5_predict"] = [
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
]
|
||||
_import_structure["pipeline_cosmos2_5_transfer"] = [
|
||||
"Cosmos2_5_TransferPipeline",
|
||||
]
|
||||
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
||||
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
@@ -41,6 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_cosmos2_5_predict import (
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
)
|
||||
from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline
|
||||
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
||||
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
|
||||
923
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py
Normal file
923
src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py
Normal file
@@ -0,0 +1,923 @@
|
||||
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
|
||||
n_pad_frames = num_frames - video.shape[2]
|
||||
if n_pad_frames > 0:
|
||||
last_frame = video[:, :, -1:, :, :]
|
||||
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
|
||||
return video
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import cv2
>>> import numpy as np
>>> import torch
>>> from PIL import Image
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video

>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
|
||||
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
|
||||
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
|
||||
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
|
||||
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # Video2World with edge control: Generate video guided by edge maps extracted from input video.
|
||||
>>> prompt = (
|
||||
... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It"
|
||||
... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige"
|
||||
... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are"
|
||||
... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is"
|
||||
... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the"
|
||||
... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper"
|
||||
... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist."
|
||||
... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm"
|
||||
... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in"
|
||||
... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision,"
|
||||
... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing"
|
||||
... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed"
|
||||
... "movements and coordination involved in the task."
|
||||
... )
|
||||
>>> negative_prompt = (
|
||||
... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
|
||||
... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
|
||||
... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky "
|
||||
... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
|
||||
... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
|
||||
... "Overall, the video is of poor quality."
|
||||
... )
|
||||
>>> input_video = load_video(
|
||||
... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4"
|
||||
... )
|
||||
>>> num_frames = 93
|
||||
|
||||
>>> # Extract edge maps from the input video using Canny edge detection
|
||||
>>> edge_maps = [
|
||||
... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200)
|
||||
... for frame in input_video[:num_frames]
|
||||
... ]
|
||||
>>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W)
|
||||
>>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W)
|
||||
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
|
||||
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
|
||||
|
||||
>>> video = pipe(
|
||||
... video=input_video[:num_frames],
|
||||
... controls=controls,
|
||||
... controls_conditioning_scale=1.0,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_frames=num_frames,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "edge_controlled_video.mp4", fps=30)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
    r"""
    Pipeline for the Cosmos Transfer2.5 base model.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5
            VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder.
        tokenizer (`AutoTokenizer`):
            Tokenizer associated with the Qwen2.5 VL encoder.
        transformer ([`CosmosTransformer3DModel`]):
            Conditional Transformer to denoise the encoded video latents.
        controlnet ([`CosmosControlNetModel`], *optional*):
            ControlNet that produces per-block residuals from the control inputs. If `None`, the ControlNet path
            is skipped.
        scheduler ([`UniPCMultistepScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    """
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker", "controlnet"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: UniPCMultistepScheduler,
|
||||
controlnet: Optional[CosmosControlNetModel],
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_mean", None) is not None
|
||||
else None
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
|
||||
if getattr(self.vae.config, "latents_std", None) is not None
|
||||
else None
|
||||
)
|
||||
self.latents_mean = latents_mean
|
||||
self.latents_std = latents_std
|
||||
|
||||
if self.latents_mean is None or self.latents_std is None:
|
||||
raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.")
|
||||
|
||||
def _get_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
input_ids_batch = []
|
||||
|
||||
for sample_idx in range(len(prompt)):
|
||||
conversations = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant who will provide prompts to an image generator.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt[sample_idx],
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=False,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
input_ids_batch = torch.stack(input_ids_batch, dim=0)
|
||||
|
||||
outputs = self.text_encoder(
|
||||
input_ids_batch.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
normalized_hidden_states = []
|
||||
for layer_idx in range(1, len(hidden_states)):
|
||||
normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / (
|
||||
hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8
|
||||
)
|
||||
normalized_hidden_states.append(normalized_state)
|
||||
|
||||
prompt_embeds = torch.cat(normalized_hidden_states, dim=-1)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
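The embedding path above standardizes every hidden layer after the input embeddings and concatenates them along the feature axis, so the prompt embedding width is `num_layers * hidden_size`; a toy illustration with made-up sizes:

import torch

batch, seq_len, hidden, num_layers = 1, 6, 8, 3
# Stand-in for `output_hidden_states=True`: the embedding output plus one tensor per layer.
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(1 + num_layers)]

normalized = [
    (h - h.mean(dim=-1, keepdim=True)) / (h.std(dim=-1, keepdim=True) + 1e-8)
    for h in hidden_states[1:]   # skip the embedding-layer output, as above
]
prompt_embeds = torch.cat(normalized, dim=-1)
print(prompt_embeds.shape)       # torch.Size([1, 6, 24]) == [B, N, num_layers * hidden]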
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and
|
||||
# diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor],
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames_in: int = 93,
|
||||
num_frames_out: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
B = batch_size
|
||||
C = num_channels_latents
|
||||
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||
H = height // self.vae_scale_factor_spatial
|
||||
W = width // self.vae_scale_factor_spatial
|
||||
shape = (B, C, T, H, W)
|
||||
|
||||
if num_frames_in == 0:
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||
|
||||
cond_latents = torch.zeros_like(latents)
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
else:
|
||||
if video is None:
|
||||
raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.")
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
cond_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
padding_shape = (B, 1, T, H, W)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
return (
|
||||
latents,
|
||||
cond_latents,
|
||||
cond_mask,
|
||||
cond_indicator,
|
||||
)
|
||||
|
||||
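With the pipeline defaults (93 output frames at 704x1280) and the usual Wan VAE factors of 4 (temporal) and 8 (spatial), which are assumptions about the loaded VAE config, the latent grid prepared above works out as follows:

num_frames_out, height, width = 93, 704, 1280
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8   # assumed Wan VAE factors

T = (num_frames_out - 1) // vae_scale_factor_temporal + 1    # 24 latent frames
H = height // vae_scale_factor_spatial                       # 88
W = width // vae_scale_factor_spatial                        # 160
print((T, H, W))  # (24, 88, 160)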
def _encode_controls(
|
||||
self,
|
||||
controls: Optional[torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
|
||||
) -> Optional[torch.Tensor]:
|
||||
if controls is None:
|
||||
return None
|
||||
|
||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||
control_video = _maybe_pad_video(control_video, num_frames)
|
||||
|
||||
control_video = control_video.to(device=device, dtype=self.vae.dtype)
|
||||
control_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
|
||||
]
|
||||
control_latents = torch.cat(control_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
control_latents = (control_latents - latents_mean) / latents_std
|
||||
return control_latents
|
||||
|
||||
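A rough usage sketch for this helper, assuming `pipe` is an already-loaded `Cosmos2_5_TransferPipeline`; the blank control frames and the printed shape are placeholders.

import numpy as np
import PIL.Image
import torch

controls = [PIL.Image.fromarray(np.zeros((704, 1280, 3), dtype=np.uint8)) for _ in range(93)]

control_latents = pipe._encode_controls(
    controls,
    height=704,
    width=1280,
    num_frames=93,
    dtype=torch.bfloat16,
    device=pipe._execution_device,
    generator=None,
)
print(control_latents.shape)  # e.g. torch.Size([1, 16, 24, 88, 160]) for a 16-channel Wan VAE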
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput | None = None,
|
||||
video: List[PipelineImageInput] | None = None,
|
||||
prompt: Union[str, List[str]] | None = None,
|
||||
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
|
||||
height: int = 704,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 93,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 3.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
|
||||
controls_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
|
||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||
|
||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame
(i.e. the corresponding *2Image mode).
|
||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||
height (`int`, defaults to `704`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image. If not provided, this will be determined based on the
|
||||
aspect ratio of the input and the provided height.
|
||||
num_frames (`int`, defaults to `93`):
|
||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||
num_inference_steps (`int`, defaults to `36`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `3.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
|
||||
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
|
||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during inference with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated videos.
|
||||
"""
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
if width is None:
|
||||
frame = image if image is not None else (video[0] if video is not None else None)
|
||||
if frame is None and controls is not None:
|
||||
frame = controls[0] if isinstance(controls, list) else controls
|
||||
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
|
||||
frame = frame[0]
|
||||
|
||||
if frame is None:
|
||||
width = int((height + 16) * (1280 / 720))
|
||||
elif isinstance(frame, PIL.Image.Image):
|
||||
width = int((height + 16) * (frame.width / frame.height))
|
||||
else:
|
||||
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
|
||||
|
||||
# Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
|
||||
# Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
img_context = torch.zeros(
|
||||
batch_size,
|
||||
self.transformer.config.img_context_num_tokens,
|
||||
self.transformer.config.img_context_dim_in,
|
||||
device=prompt_embeds.device,
|
||||
dtype=transformer_dtype,
|
||||
)
|
||||
encoder_hidden_states = (prompt_embeds, img_context)
|
||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||
|
||||
num_frames_in = None
|
||||
if image is not None:
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
||||
|
||||
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
||||
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
||||
video = video.unsqueeze(0)
|
||||
num_frames_in = 1
|
||||
elif video is None:
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
num_frames_out = num_frames
|
||||
video = _maybe_pad_video(video, num_frames_out)
|
||||
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
||||
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=num_frames_in,
|
||||
num_frames_out=num_frames,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
|
||||
controls_latents = None
|
||||
if controls is not None:
|
||||
controls_latents = self._encode_controls(
|
||||
controls,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
dtype=transformer_dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
)
|
||||
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
condition_mask=cond_mask,
|
||||
conditioning_scale=controls_conditioning_scale,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
condition_mask=cond_mask,
|
||||
conditioning_scale=controls_conditioning_scale,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
video = self._match_num_frames(video, num_frames)
|
||||
|
||||
assert self.safety_checker is not None
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).round().astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
if vid is None:
|
||||
video_batch.append(np.zeros_like(video[0]))
|
||||
else:
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
|
||||
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
|
||||
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
|
||||
return video
|
||||
|
||||
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
|
||||
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
|
||||
|
||||
current_frames = video.shape[2]
|
||||
if current_frames < target_num_frames:
|
||||
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
|
||||
video = torch.cat([video, pad], dim=2)
|
||||
elif current_frames > target_num_frames:
|
||||
video = video[:, :, :target_num_frames]
|
||||
|
||||
return video
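For orientation, here is a minimal usage sketch of the three modes described in the `__call__` docstring above, using the `Cosmos2_5_TransferPipeline` export added elsewhere in this patch. The checkpoint id, the control video path, and the export fps are placeholders for illustration, not values defined by this diff.

```py
import torch
from diffusers import Cosmos2_5_TransferPipeline
from diffusers.utils import export_to_video, load_video

# "nvidia/cosmos-2.5-transfer-placeholder" is a hypothetical repo id.
pipe = Cosmos2_5_TransferPipeline.from_pretrained(
    "nvidia/cosmos-2.5-transfer-placeholder", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A robot arm assembling a wooden chair in a sunlit workshop."
controls = load_video("edge_control.mp4")  # e.g. an edge-map control clip

# Text2World with ControlNet conditioning; guidance_scale > 1 enables CFG.
output = pipe(
    prompt=prompt,
    controls=controls,
    controls_conditioning_scale=0.8,
    num_frames=93,              # 93 -> world video, 1 -> single frame
    num_inference_steps=36,
    guidance_scale=3.0,
)
export_to_video(output.frames[0], "cosmos_transfer.mp4", fps=24)  # fps is illustrative
```

Passing `image=...` or `video=...` instead of `controls` (or alongside it) switches to the Image2World / Video2World paths shown above.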
|
||||
@@ -658,12 +658,7 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if prompt is not None and prior_token_ids is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prior_token_ids is None:
|
||||
if prompt is None and prior_token_ids is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
|
||||
)
|
||||
@@ -694,8 +689,8 @@ class GlmImagePipeline(DiffusionPipeline):
|
||||
"for i2i mode, as the images are needed for VAE encoding to build the KV cache."
|
||||
)
|
||||
|
||||
if prior_token_ids is not None and prompt_embeds is None:
|
||||
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
|
||||
if prior_token_ids is not None and prompt_embeds is None and prompt is None:
|
||||
raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
|
||||
@@ -13,12 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...utils import is_av_available
|
||||
from ...utils import get_logger, is_av_available
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_CAN_USE_AV = is_av_available()
|
||||
@@ -101,11 +109,59 @@ def _write_audio(
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
|
||||
fps: int,
|
||||
audio: Optional[torch.Tensor],
|
||||
audio_sample_rate: Optional[int],
|
||||
output_path: str,
|
||||
video_chunks_number: int = 1,
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
"""
|
||||
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
|
||||
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
Args:
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, or `Iterator[torch.Tensor]`):
|
||||
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
|
||||
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
|
||||
usually return with `output_type="np"`).
|
||||
fps (`int`):
|
||||
The frames per second (FPS) of the encoded video.
|
||||
audio (`torch.Tensor`, *optional*):
|
||||
An audio waveform of shape [audio_channels, samples].
|
||||
audio_sample_rate (`int`, *optional*):
|
||||
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
|
||||
output_path (`str`):
|
||||
The path to save the encoded video to.
|
||||
video_chunks_number (`int`, *optional*, defaults to `1`):
|
||||
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
|
||||
number of chunks to use often depends on the tiling config for the video VAE.
|
||||
"""
|
||||
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
|
||||
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
|
||||
video_frames = [np.array(frame) for frame in video]
|
||||
video = np.stack(video_frames, axis=0)
|
||||
video = torch.from_numpy(video)
|
||||
elif isinstance(video, np.ndarray):
|
||||
# Pipeline output_type="np"
|
||||
is_normalized = np.logical_and(0.0 <= video, video <= 1.0)
if np.all(is_normalized):
|
||||
video = (video * 255).round().astype("uint8")
|
||||
else:
|
||||
logger.warning(
|
||||
"Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
|
||||
"values in [0, ..., 255] and will be used as is."
|
||||
)
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
if isinstance(video, torch.Tensor):
|
||||
# Split into video_chunks_number along the frame dimension
|
||||
video = torch.tensor_split(video, video_chunks_number, dim=0)
|
||||
video = iter(video)
|
||||
|
||||
first_chunk = next(video)
|
||||
|
||||
_, height, width, _ = first_chunk.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
@@ -119,10 +175,12 @@ def encode_video(
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
|
||||
video_chunk_cpu = video_chunk.to("cpu").numpy()
|
||||
for frame_array in video_chunk_cpu:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
|
||||
@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
@@ -1083,6 +1081,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
|
||||
audio_latents.shape[0], audio_num_frames, audio_latents.device
|
||||
)
|
||||
# Duplicate the positional ids as well if using CFG
|
||||
if self.do_classifier_free_guidance:
|
||||
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
|
||||
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -48,7 +48,7 @@ EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LTX2Pipeline
|
||||
>>> from diffusers import LTX2ImageToVideoPipeline
|
||||
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
@@ -62,7 +62,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> frame_rate = 24.0
|
||||
>>> video = pipe(
|
||||
>>> video, audio = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
@@ -1141,6 +1139,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
|
||||
audio_latents.shape[0], audio_num_frames, audio_latents.device
|
||||
)
|
||||
# Duplicate the positional ids as well if using CFG
|
||||
if self.do_classifier_free_guidance:
|
||||
video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
|
||||
audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -18,7 +18,6 @@ import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import ftfy
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
@@ -19,7 +19,6 @@ import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
@@ -41,7 +41,7 @@ class GGUFQuantizer(DiffusersQuantizer):
|
||||
|
||||
self.compute_dtype = quantization_config.compute_dtype
|
||||
self.pre_quantized = quantization_config.pre_quantized
|
||||
self.modules_to_not_convert = quantization_config.modules_to_not_convert
|
||||
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
|
||||
|
||||
if not isinstance(self.modules_to_not_convert, list):
|
||||
self.modules_to_not_convert = [self.modules_to_not_convert]
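The `or []` fallback matters because later code iterates over and tests membership in `modules_to_not_convert`; with `None` those operations raise. A plain-Python illustration (not diffusers code):

```py
modules_to_not_convert = None

# Membership tests fail on None:
#   "proj_out" in modules_to_not_convert  ->  TypeError: argument of type 'NoneType' is not iterable

# ...but behave as expected on an empty list:
modules_to_not_convert = modules_to_not_convert or []
print("proj_out" in modules_to_not_convert)                              # False
print(any(m in "transformer.proj_out" for m in modules_to_not_convert))  # False
```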
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -51,13 +51,15 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
solver_order (`int`, defaults to 2):
|
||||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
rho (`float`, *optional*, defaults to 7.0):
|
||||
The rho parameter in the Karras sigma schedule. This was set to 7.0 in the EDM paper [1].
|
||||
solver_order (`int`, defaults to 2):
|
||||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
@@ -94,19 +96,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_min: float = 0.002,
|
||||
sigma_max: float = 80.0,
|
||||
sigma_data: float = 0.5,
|
||||
sigma_schedule: str = "karras",
|
||||
sigma_schedule: Literal["karras", "exponential"] = "karras",
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
rho: float = 7.0,
|
||||
solver_order: int = 2,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", # "zero", "sigma_min"
|
||||
):
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
|
||||
@@ -145,19 +147,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
def init_noise_sigma(self) -> float:
|
||||
# standard deviation of the initial noise distribution
|
||||
return (self.config.sigma_max**2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
def step_index(self) -> int:
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
def begin_index(self) -> int:
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
@@ -274,7 +276,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -460,13 +466,12 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
|
||||
sigma_t = sigma
|
||||
|
||||
return alpha_t, sigma_t
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
sample: torch.Tensor = None,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
||||
@@ -497,7 +502,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
sample: torch.Tensor = None,
|
||||
sample: torch.Tensor,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -508,6 +513,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor to add to the original samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -538,7 +545,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
sample: torch.Tensor = None,
|
||||
sample: torch.Tensor,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -549,6 +556,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor to add to the original samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -609,7 +618,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
sample: torch.Tensor = None,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
One step for the third-order multistep DPMSolver.
|
||||
@@ -698,7 +707,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
@@ -719,7 +728,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -860,5 +869,5 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
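The `Literal[...]` annotations added above are documentation for type checkers only; runtime validation is unchanged. A minimal, standalone sketch of constructing the scheduler with the annotated options (no checkpoint required):

```py
from diffusers import EDMDPMSolverMultistepScheduler

scheduler = EDMDPMSolverMultistepScheduler(
    sigma_schedule="exponential",      # Literal["karras", "exponential"]
    prediction_type="epsilon",         # Literal["epsilon", "sample", "v_prediction"]
    algorithm_type="sde-dpmsolver++",  # Literal["dpmsolver++", "sde-dpmsolver++"]
    solver_type="midpoint",            # Literal["midpoint", "heun"]
    final_sigmas_type="zero",
)
scheduler.set_timesteps(num_inference_steps=30)
print(len(scheduler.timesteps), scheduler.init_noise_sigma)

# Values outside the Literal set still fail at construction time, exactly as before
# (the validation in __init__ above raises for unsupported algorithm/solver types).
```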
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -102,12 +102,21 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
time_shift_type: str = "exponential",
|
||||
time_shift_type: Literal["exponential", "linear"] = "exponential",
|
||||
stochastic_sampling: bool = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
@@ -166,6 +175,13 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_shift(self, shift: float):
|
||||
"""
|
||||
Sets the shift value for the scheduler.
|
||||
|
||||
Args:
|
||||
shift (`float`):
|
||||
The shift value to be set.
|
||||
"""
|
||||
self._shift = shift
|
||||
|
||||
def scale_noise(
|
||||
@@ -218,10 +234,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
def _sigma_to_t(self, sigma) -> float:
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply time shifting to the sigmas.
|
||||
|
||||
Args:
|
||||
mu (`float`):
|
||||
The mu parameter for the time shift.
|
||||
sigma (`float`):
|
||||
The sigma parameter for the time shift.
|
||||
t (`torch.Tensor`):
|
||||
The input timesteps.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The time-shifted timesteps.
|
||||
"""
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
@@ -302,7 +333,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if sigmas is None:
|
||||
if timesteps is None:
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
self._sigma_to_t(self.sigma_max),
|
||||
self._sigma_to_t(self.sigma_min),
|
||||
num_inference_steps,
|
||||
)
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
else:
|
||||
@@ -350,7 +383,24 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
schedule_timesteps: Optional[torch.FloatTensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get the index for the given timestep.
|
||||
|
||||
Args:
|
||||
timestep (`float` or `torch.FloatTensor`):
|
||||
The timestep to find the index for.
|
||||
schedule_timesteps (`torch.FloatTensor`, *optional*):
|
||||
The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -364,7 +414,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -405,7 +455,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
A random number generator.
|
||||
per_token_timesteps (`torch.Tensor`, *optional*):
|
||||
The timesteps for each token in the sample.
|
||||
return_dict (`bool`):
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
|
||||
|
||||
@@ -474,7 +524,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
@@ -595,11 +645,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def _time_shift_linear(self, mu, sigma, t):
|
||||
def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -51,9 +51,6 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
@@ -110,7 +107,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
@@ -119,7 +116,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`torch.FloatTensor`):
|
||||
timestep (`float` or `torch.FloatTensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
noise (`torch.FloatTensor`):
|
||||
The noise tensor.
|
||||
@@ -137,10 +134,14 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
def _sigma_to_t(self, sigma: float) -> float:
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Union[str, torch.device] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -153,7 +154,9 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
self._sigma_to_t(self.sigma_max),
|
||||
self._sigma_to_t(self.sigma_min),
|
||||
num_inference_steps,
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
@@ -174,7 +177,24 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
schedule_timesteps: Optional[torch.FloatTensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index of a given timestep in the timestep schedule.
|
||||
|
||||
Args:
|
||||
timestep (`float` or `torch.FloatTensor`):
|
||||
The timestep value to find in the schedule.
|
||||
schedule_timesteps (`torch.FloatTensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@@ -188,7 +208,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[float, torch.FloatTensor]) -> None:
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -197,7 +217,10 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._step_index = self._begin_index
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
def state_in_first_order(self) -> bool:
|
||||
"""
|
||||
Returns whether the scheduler is in the first-order state.
|
||||
"""
|
||||
return self.dt is None
|
||||
|
||||
def step(
|
||||
@@ -219,13 +242,19 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`float` or `torch.FloatTensor`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
s_churn (`float`):
|
||||
s_tmin (`float`):
|
||||
s_tmax (`float`):
|
||||
Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
|
||||
randomness.
|
||||
s_tmin (`float`):
|
||||
Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise
|
||||
added.
|
||||
s_tmax (`float`):
|
||||
Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise
|
||||
added.
|
||||
s_noise (`float`, defaults to 1.0):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
@@ -274,7 +303,10 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if gamma > 0:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||
model_output.shape,
|
||||
dtype=model_output.dtype,
|
||||
device=model_output.device,
|
||||
generator=generator,
|
||||
)
|
||||
eps = noise * s_noise
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
@@ -320,5 +352,5 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
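The `s_churn` / `s_tmin` / `s_tmax` / `s_noise` parameters documented above implement Karras-style stochastic churn: when `s_tmin <= sigma <= s_tmax`, a `gamma` derived from `s_churn` inflates the current sigma to `sigma_hat = sigma * (gamma + 1)`, and fresh noise scaled by `s_noise` bridges the gap. A rough standalone sketch of that update, not the scheduler's exact code path:

```py
import torch


def churned_sigma(sigma: float, s_churn: float, s_tmin: float, s_tmax: float, num_sigmas: int) -> float:
    # gamma > 0 only inside the [s_tmin, s_tmax] window
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    return sigma * (gamma + 1)


sigma, s_noise = 1.2, 1.0
sigma_hat = churned_sigma(sigma, s_churn=0.5, s_tmin=0.05, s_tmax=999.0, num_sigmas=30)

sample = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(sample) * s_noise
# add just enough noise to move the sample from noise level sigma up to sigma_hat
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
```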
|
||||
|
||||
@@ -482,6 +482,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
"""
|
||||
Apply time shifting to the sigmas.
|
||||
|
||||
Args:
|
||||
mu (`float`):
|
||||
The mu parameter for the time shift.
|
||||
sigma (`float`):
|
||||
The sigma parameter for the time shift.
|
||||
t (`torch.Tensor`):
|
||||
The input timesteps.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The time-shifted timesteps.
|
||||
"""
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
|
||||
@@ -896,6 +896,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CosmosControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -977,6 +977,21 @@ class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2_5_TransferPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
0
tests/models/controlnets/__init__.py
Normal file
255
tests/models/controlnets/test_models_controlnet_cosmos.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CosmosControlNetModel
|
||||
from diffusers.models.controlnets.controlnet_cosmos import CosmosControlNetOutput
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CosmosControlNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CosmosControlNetModel
|
||||
main_input_name = "controls_latents"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_embed_dim = 32
|
||||
sequence_length = 12
|
||||
img_context_dim_in = 32
|
||||
img_context_num_tokens = 4
|
||||
|
||||
# Raw latents (not patchified) - the controlnet computes embeddings internally
|
||||
controls_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.tensor([0.5]).to(torch_device) # Diffusion timestep
|
||||
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
|
||||
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
|
||||
|
||||
# Text embeddings
|
||||
text_context = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
|
||||
# Image context for Cosmos 2.5
|
||||
img_context = torch.randn((batch_size, img_context_num_tokens, img_context_dim_in)).to(torch_device)
|
||||
encoder_hidden_states = (text_context, img_context)
|
||||
|
||||
return {
|
||||
"controls_latents": controls_latents,
|
||||
"latents": latents,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"condition_mask": condition_mask,
|
||||
"conditioning_scale": 1.0,
|
||||
"padding_mask": padding_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
# Output is tuple of n_controlnet_blocks tensors, each with shape (batch, num_patches, model_channels)
|
||||
# After stacking by normalize_output: (n_blocks, batch, num_patches, model_channels)
|
||||
# For test config: n_blocks=2, num_patches=64 (1*8*8), model_channels=32
|
||||
# output_shape is used as (batch_size,) + output_shape, so: (2, 64, 32)
|
||||
return (2, 64, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"n_controlnet_blocks": 2,
|
||||
"in_channels": 16 + 1 + 1, # control_latent_channels + condition_mask + padding_mask
|
||||
"latent_channels": 16 + 1 + 1, # base_latent_channels (16) + condition_mask (1) + padding_mask (1) = 18
|
||||
"model_channels": 32,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 16,
|
||||
"mlp_ratio": 2,
|
||||
"text_embed_dim": 32,
|
||||
"adaln_lora_dim": 4,
|
||||
"patch_size": (1, 2, 2),
|
||||
"max_size": (4, 32, 32),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"extra_pos_embed_type": None,
|
||||
"img_context_dim_in": 32,
|
||||
"img_context_dim_out": 32,
|
||||
"use_crossattn_projection": False, # Test doesn't need this projection
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output_format(self):
|
||||
"""Test that the model outputs CosmosControlNetOutput with correct structure."""
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
self.assertIsInstance(output, CosmosControlNetOutput)
|
||||
self.assertIsInstance(output.control_block_samples, list)
|
||||
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
|
||||
for tensor in output.control_block_samples:
|
||||
self.assertIsInstance(tensor, torch.Tensor)
|
||||
|
||||
def test_output_list_format(self):
|
||||
"""Test that return_dict=False returns a tuple containing a list."""
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict, return_dict=False)
|
||||
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertIsInstance(output[0], list)
|
||||
self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"])
|
||||
|
||||
def test_conditioning_scale_single(self):
|
||||
"""Test that a single conditioning scale is broadcast to all blocks."""
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict["conditioning_scale"] = 0.5
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
|
||||
|
||||
def test_conditioning_scale_list(self):
|
||||
"""Test that a list of conditioning scales is applied per block."""
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# Provide a scale for each block
|
||||
inputs_dict["conditioning_scale"] = [0.5, 1.0]
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])
    def test_forward_with_none_img_context(self):
        """Test forward pass when img_context is None."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # Set encoder_hidden_states to (text_context, None)
        text_context = inputs_dict["encoder_hidden_states"][0]
        inputs_dict["encoder_hidden_states"] = (text_context, None)

        with torch.no_grad():
            output = model(**inputs_dict)

        self.assertIsInstance(output, CosmosControlNetOutput)
        self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])

    def test_forward_without_img_context_proj(self):
        """Test forward pass when img_context_proj is not configured."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        # Disable img_context_proj
        init_dict["img_context_dim_in"] = None
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # When img_context is disabled, pass only text context (not a tuple)
        text_context = inputs_dict["encoder_hidden_states"][0]
        inputs_dict["encoder_hidden_states"] = text_context

        with torch.no_grad():
            output = model(**inputs_dict)

        self.assertIsInstance(output, CosmosControlNetOutput)
        self.assertEqual(len(output.control_block_samples), init_dict["n_controlnet_blocks"])

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CosmosControlNetModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    # Note: test_set_attn_processor_for_determinism already handles uses_custom_attn_processor=True
    # so no explicit skip needed for it
    # Note: test_forward_signature and test_set_default_attn_processor don't exist in base class

    # Skip tests that don't apply to this architecture
    @unittest.skip("CosmosControlNetModel doesn't use norm groups.")
    def test_forward_with_norm_groups(self):
        pass

    # Skip tests that expect .sample attribute - ControlNets don't have this
    @unittest.skip("ControlNet output doesn't have .sample attribute")
    def test_effective_gradient_checkpointing(self):
        pass

    # Skip tests that compute MSE loss against single tensor output
    @unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
    def test_ema_training(self):
        pass

    @unittest.skip("ControlNet outputs list of control blocks, not single tensor for MSE loss")
    def test_training(self):
        pass

    # Skip tests where output shape comparison doesn't apply to ControlNets
    @unittest.skip("ControlNet output shape doesn't match input shape by design")
    def test_output(self):
        pass

    # Skip outputs_equivalence - dict/list comparison logic not compatible (recursive_check expects dict.values())
    @unittest.skip("ControlNet output structure not compatible with recursive dict check")
    def test_outputs_equivalence(self):
        pass

    # Skip model parallelism - base test uses torch.allclose(base_output[0], new_output[0]) which fails
    # because output[0] is the list of control_block_samples, not a tensor
    @unittest.skip("test_model_parallelism uses torch.allclose on output[0] which is a list, not a tensor")
    def test_model_parallelism(self):
        pass

    # Skip layerwise casting tests - these have two issues:
    # 1. _inference and _memory: dtype compatibility issues with learnable_pos_embed and float8/bfloat16
    # 2. _training: same as test_training - mse_loss expects tensor, not list
    @unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip("Layerwise casting has dtype issues with learnable_pos_embed")
    def test_layerwise_casting_memory(self):
        pass

    @unittest.skip("test_layerwise_casting_training computes mse_loss on list output")
    def test_layerwise_casting_training(self):
        pass

@@ -446,16 +446,17 @@ class ModelTesterMixin:
        torch_device not in ["cuda", "xpu"],
        reason="float16 and bfloat16 can only be used with an accelerator",
    )
    def test_keep_in_fp32_modules(self):
    def test_keep_in_fp32_modules(self, tmp_path):
        model = self.model_class(**self.get_init_dict())
        fp32_modules = model._keep_in_fp32_modules

        if fp32_modules is None or len(fp32_modules) == 0:
            pytest.skip("Model does not have _keep_in_fp32_modules defined.")

        # Test with float16
        model.to(torch_device)
        model.to(torch.float16)
        # Save the model and reload with float16 dtype
        # _keep_in_fp32_modules is only enforced during from_pretrained loading
        model.save_pretrained(tmp_path)
        model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)

        for name, param in model.named_parameters():
            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
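            # Sketch for context (not part of the diff; the hunk is truncated here): the loop above
            # presumably asserts that modules listed in _keep_in_fp32_modules stay in float32 after
            # from_pretrained(..., torch_dtype=torch.float16), while all other parameters are float16:
            #
            #     if any(m in name.split(".") for m in fp32_modules):
            #         assert param.dtype == torch.float32
            #     else:
            #         assert param.dtype == torch.float16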
@@ -470,7 +471,7 @@ class ModelTesterMixin:
    )
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    @torch.no_grad()
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        fp32_modules = model._keep_in_fp32_modules or []
@@ -490,10 +491,6 @@ class ModelTesterMixin:
        output = model(**inputs, return_dict=False)[0]
        output_loaded = model_loaded(**inputs, return_dict=False)[0]

        self._check_dtype_inference_output(output, output_loaded, dtype)

    def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
        """Check dtype inference output with configurable tolerance."""
        assert_tensors_close(
            output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
        )
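    # Hypothetical usage (not part of the diff): a model test class that needs a looser tolerance
    # can override the helper instead of reimplementing the whole test.
    #
    #     class MyModelTests(MyModelTesterConfig, ModelTesterMixin):
    #         def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-2, rtol=0):
    #             super()._check_dtype_inference_output(output, output_loaded, dtype, atol=atol, rtol=rtol)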
@@ -176,15 +176,7 @@ class QuantizationTesterMixin:
        model_quantized = self._create_quantized_model(config_kwargs)
        model_quantized.to(torch_device)

        # Get model dtype from first parameter
        model_dtype = next(model_quantized.parameters()).dtype

        inputs = self.get_dummy_inputs()
        # Cast inputs to model dtype
        inputs = {
            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
            for k, v in inputs.items()
        }
        output = model_quantized(**inputs, return_dict=False)[0]

        assert output is not None, "Model output is None"
@@ -229,6 +221,8 @@ class QuantizationTesterMixin:
            init_lora_weights=False,
        )
        model.add_adapter(lora_config)
        # Move LoRA adapter weights to device (they default to CPU)
        model.to(torch_device)

        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]
@@ -1021,9 +1015,6 @@ class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
        """Test that dequantize() works correctly."""
        self._test_dequantize({"compute_dtype": torch.bfloat16})

    def test_gguf_quantized_layers(self):
        self._test_quantized_layers({"compute_dtype": torch.bfloat16})


@is_quantization
@is_modelopt
@@ -12,57 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

from diffusers import WanTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
    enable_full_determinism,
    torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = WanTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return WanTransformer3DModel

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-wan22-transformer"

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
    @property
    def output_shape(self) -> tuple[int, ...]:
        return (4, 2, 16, 16)

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (4, 2, 16, 16)

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (4, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
@@ -76,16 +76,160 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
            "qk_norm": "rms_norm_across_heads",
            "rope_max_seq_len": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12

        return {
            "hidden_states": randn_tensor(
                (batch_size, num_channels, num_frames, height, width),
                generator=self.generator,
                device=torch_device,
            ),
            "encoder_hidden_states": randn_tensor(
                (batch_size, sequence_length, text_encoder_embedding_dim),
                generator=self.generator,
                device=torch_device,
            ),
            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
        }


class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Wan Transformer 3D."""

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        # Skip: fp16/bf16 require very high atol to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Wan Transformer 3D."""


class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Wan Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"WanTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = WanTransformer3DModel
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Wan Transformer 3D."""

    def prepare_init_args_and_inputs_for_common(self):
        return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Wan Transformer 3D."""


class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Wan Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.float16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for Wan Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
    """GGUF quantization tests for Wan Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
        return super()._create_quantized_model(
            config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
        )

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan I2V model dimensions.

        Wan 2.2 I2V: in_channels=36, text_dim=4096
        """
        return {
            "hidden_states": randn_tensor(
                (1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
    """GGUF + compile tests for Wan Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
        return super()._create_quantized_model(
            config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
        )

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan I2V model dimensions.

        Wan 2.2 I2V: in_channels=36, text_dim=4096
        """
        return {
            "hidden_states": randn_tensor(
                (1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }
@@ -12,76 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

from diffusers import WanAnimateTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
    enable_full_determinism,
    torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = WanAnimateTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return WanAnimateTransformer3DModel

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 20  # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12

        clip_seq_len = 12
        clip_dim = 16

        inference_segment_length = 77  # The inference segment length in the full Wan2.2-Animate-14B model
        face_height = 16  # Should be square and match `motion_encoder_size` below
        face_width = 16

        hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
        clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
        pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
            torch_device
        )

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_image": clip_ref_features,
            "pose_hidden_states": pose_latents,
            "face_pixel_values": face_pixel_values,
        }
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-wan-animate-transformer"

    @property
    def input_shape(self):
        return (12, 1, 16, 16)
    def output_shape(self) -> tuple[int, ...]:
        # Output has fewer channels than input (4 vs 12)
        return (4, 21, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)
    def input_shape(self) -> tuple[int, ...]:
        return (12, 21, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
        # Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
        # contain the vast majority of the parameters in the test model
        channel_sizes = {"4": 16, "8": 16, "16": 16}

        init_dict = {
        return {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
@@ -105,22 +91,219 @@ class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
            "face_encoder_num_heads": 2,
            "inject_face_latents_blocks": 2,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_channels = 4
        num_frames = 20  # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12

        clip_seq_len = 12
        clip_dim = 16

        inference_segment_length = 77  # The inference segment length in the full Wan2.2-Animate-14B model
        face_height = 16  # Should be square and match `motion_encoder_size`
        face_width = 16

        return {
            "hidden_states": randn_tensor(
                (batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
                generator=self.generator,
                device=torch_device,
            ),
            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
            "encoder_hidden_states": randn_tensor(
                (batch_size, sequence_length, text_encoder_embedding_dim),
                generator=self.generator,
                device=torch_device,
            ),
            "encoder_hidden_states_image": randn_tensor(
                (batch_size, clip_seq_len, clip_dim),
                generator=self.generator,
                device=torch_device,
            ),
            "pose_hidden_states": randn_tensor(
                (batch_size, num_channels, num_frames, height, width),
                generator=self.generator,
                device=torch_device,
            ),
            "face_pixel_values": randn_tensor(
                (batch_size, 3, inference_segment_length, face_height, face_width),
                generator=self.generator,
                device=torch_device,
            ),
        }


class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Wan Animate Transformer 3D."""

    def test_output(self):
        # Override test_output because the transformer output is expected to have fewer channels
        # than the main transformer input.
        expected_output_shape = (1, 4, 21, 16, 16)
        super().test_output(expected_output_shape=expected_output_shape)

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        # Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Wan Animate Transformer 3D."""


class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Wan Animate Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"WanAnimateTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    # Override test_output because the transformer output is expected to have fewer channels than the main transformer
    # input.
    def test_output(self):
        expected_output_shape = (1, 4, 21, 16, 16)
        super().test_output(expected_output_shape=expected_output_shape)

class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Wan Animate Transformer 3D."""


class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = WanAnimateTransformer3DModel
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Wan Animate Transformer 3D."""

    def prepare_init_args_and_inputs_for_common(self):
        return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
    def test_torch_compile_recompilation_and_graph_break(self):
        # Skip: F.pad with mode="replicate" in WanAnimateFaceEncoder triggers importlib.import_module
        # internally, which dynamo doesn't support tracing through.
        pytest.skip("F.pad with replicate mode triggers unsupported import in torch.compile")
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Wan Animate Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.float16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan Animate model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states_image": randn_tensor(
                (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "pose_hidden_states": randn_tensor(
                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "face_pixel_values": randn_tensor(
                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for Wan Animate Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan Animate model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states_image": randn_tensor(
                (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "pose_hidden_states": randn_tensor(
                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "face_pixel_values": randn_tensor(
                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
    """GGUF quantization tests for Wan Animate Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan Animate model dimensions.

        Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
        """
        return {
            "hidden_states": randn_tensor(
                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states_image": randn_tensor(
                (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "pose_hidden_states": randn_tensor(
                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "face_pixel_values": randn_tensor(
                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
    """GGUF + compile tests for Wan Animate Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan Animate model dimensions.

        Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
        """
        return {
            "hidden_states": randn_tensor(
                (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states_image": randn_tensor(
                (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "pose_hidden_states": randn_tensor(
                (1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "face_pixel_values": randn_tensor(
                (1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }
271
tests/models/transformers/test_models_transformer_wan_vace.py
Normal file
@@ -0,0 +1,271 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import WanVACETransformer3DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return WanVACETransformer3DModel

    @property
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-wan-vace-transformer"

    @property
    def output_shape(self) -> tuple[int, ...]:
        return (16, 2, 16, 16)

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (16, 2, 16, 16)

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
        return {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "in_channels": 16,
            "out_channels": 16,
            "text_dim": 32,
            "freq_dim": 256,
            "ffn_dim": 32,
            "num_layers": 4,
            "cross_attn_norm": True,
            "qk_norm": "rms_norm_across_heads",
            "rope_max_seq_len": 32,
            "vace_layers": [0, 2],
            "vace_in_channels": 48,  # 3 * in_channels = 3 * 16 = 48
        }

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_channels = 16
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 32
        sequence_length = 12

        # VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
        vace_in_channels = 48

        return {
            "hidden_states": randn_tensor(
                (batch_size, num_channels, num_frames, height, width),
                generator=self.generator,
                device=torch_device,
            ),
            "encoder_hidden_states": randn_tensor(
                (batch_size, sequence_length, text_encoder_embedding_dim),
                generator=self.generator,
                device=torch_device,
            ),
            "control_hidden_states": randn_tensor(
                (batch_size, vace_in_channels, num_frames, height, width),
                generator=self.generator,
                device=torch_device,
            ),
            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
        }


class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Wan VACE Transformer 3D."""

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        # Skip: fp16/bf16 require very high atol to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        pytest.skip("Tolerance requirements too high for meaningful test")

    def test_model_parallelism(self, tmp_path):
        # Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
        pytest.skip("Model parallelism not yet supported for WanVACE")


class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Wan VACE Transformer 3D."""


class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Wan VACE Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"WanVACETransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Wan VACE Transformer 3D."""


class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Wan VACE Transformer 3D."""

    def test_torch_compile_repeated_blocks(self):
        # WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
        # so we need recompile_limit=2 instead of the default 1.
        import torch._dynamo
        import torch._inductor.utils

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile_repeated_blocks(fullgraph=True)

        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(recompile_limit=2),
        ):
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Wan VACE Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.float16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan VACE model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "control_hidden_states": randn_tensor(
                (1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for Wan VACE Transformer 3D."""

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the tiny Wan VACE model dimensions."""
        return {
            "hidden_states": randn_tensor(
                (1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "control_hidden_states": randn_tensor(
                (1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
    """GGUF quantization tests for Wan VACE Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan VACE model dimensions.

        Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
        """
        return {
            "hidden_states": randn_tensor(
                (1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "control_hidden_states": randn_tensor(
                (1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
    """GGUF + compile tests for Wan VACE Transformer 3D."""

    @property
    def gguf_filename(self):
        return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"

    @property
    def torch_dtype(self):
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan VACE model dimensions.

        Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
        """
        return {
            "hidden_states": randn_tensor(
                (1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "control_hidden_states": randn_tensor(
                (1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }
@@ -6,7 +6,7 @@ import pytest
import torch

import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
@@ -37,6 +37,9 @@ class ModularPipelineTesterMixin:
    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
    # this is modular specific: generator needs to be an intermediate input because it's mutable
    intermediate_params = frozenset(["generator"])
    # Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines)
    # Subclasses can override this to change the expected output type
    output_name = "images"

    def get_generator(self, seed=0):
        generator = torch.Generator("cpu").manual_seed(seed)
@@ -163,7 +166,7 @@ class ModularPipelineTesterMixin:

        logger.setLevel(level=diffusers.logging.WARNING)
        for batch_size, batched_input in zip(batch_sizes, batched_inputs):
            output = pipe(**batched_input, output="images")
            output = pipe(**batched_input, output=self.output_name)
            assert len(output) == batch_size, "Output is different from expected batch size"

    def test_inference_batch_single_identical(
@@ -197,12 +200,16 @@ class ModularPipelineTesterMixin:
        if "batch_size" in inputs:
            batched_inputs["batch_size"] = batch_size

        output = pipe(**inputs, output="images")
        output_batch = pipe(**batched_inputs, output="images")
        output = pipe(**inputs, output=self.output_name)
        output_batch = pipe(**batched_inputs, output=self.output_name)

        assert output_batch.shape[0] == batch_size

        max_diff = torch.abs(output_batch[0] - output[0]).max()
        # For batch comparison, we only need to compare the first item
        if output_batch.shape[0] == batch_size and output.shape[0] == 1:
            output_batch = output_batch[0:1]

        max_diff = torch.abs(output_batch - output).max()
        assert max_diff < expected_max_diff, "Batch inference results different from single inference results"

    @require_accelerator
@@ -217,19 +224,32 @@ class ModularPipelineTesterMixin:
        # Reset generator in case it is used inside dummy inputs
        if "generator" in inputs:
            inputs["generator"] = self.get_generator(0)
        output = pipe(**inputs, output="images")

        output = pipe(**inputs, output=self.output_name)

        fp16_inputs = self.get_dummy_inputs()
        # Reset generator in case it is used inside dummy inputs
        if "generator" in fp16_inputs:
            fp16_inputs["generator"] = self.get_generator(0)
        output_fp16 = pipe_fp16(**fp16_inputs, output="images")

        output = output.cpu()
        output_fp16 = output_fp16.cpu()
        output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name)

        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
        assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
        output_tensor = output.float().cpu()
        output_fp16_tensor = output_fp16.float().cpu()

        # Check for NaNs in outputs (can happen with tiny models in FP16)
        if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any():
            pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models")

        max_diff = numpy_cosine_similarity_distance(
            output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy()
        )

        # Check if cosine similarity is NaN (which can happen if vectors are zero or very small)
        if torch.isnan(torch.tensor(max_diff)):
            pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison")

        assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})"

    @require_accelerator
    def test_to_device(self):
@@ -251,14 +271,16 @@ class ModularPipelineTesterMixin:
    def test_inference_is_not_nan_cpu(self):
        pipe = self.get_pipeline().to("cpu")

        output = pipe(**self.get_dummy_inputs(), output="images")
        inputs = self.get_dummy_inputs()
        output = pipe(**inputs, output=self.output_name)
        assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"

    @require_accelerator
    def test_inference_is_not_nan(self):
        pipe = self.get_pipeline().to(torch_device)

        output = pipe(**self.get_dummy_inputs(), output="images")
        inputs = self.get_dummy_inputs()
        output = pipe(**inputs, output=self.output_name)
        assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"

    def test_num_images_per_prompt(self):
@@ -278,7 +300,7 @@ class ModularPipelineTesterMixin:
            if key in self.batch_params:
                inputs[key] = batch_size * [inputs[key]]

        images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
        images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name)

        assert images.shape[0] == batch_size * num_images_per_prompt

@@ -293,8 +315,7 @@ class ModularPipelineTesterMixin:
        image_slices = []
        for pipe in [base_pipe, offload_pipe]:
            inputs = self.get_dummy_inputs()
            image = pipe(**inputs, output="images")

            image = pipe(**inputs, output=self.output_name)
            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -315,8 +336,7 @@ class ModularPipelineTesterMixin:
        image_slices = []
        for pipe in pipes:
            inputs = self.get_dummy_inputs()
            image = pipe(**inputs, output="images")

            image = pipe(**inputs, output=self.output_name)
            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -331,13 +351,13 @@ class ModularGuiderTesterMixin:
        pipe.update_components(guider=guider)

        inputs = self.get_dummy_inputs()
        out_no_cfg = pipe(**inputs, output="images")
        out_no_cfg = pipe(**inputs, output=self.output_name)

        # forward pass with CFG applied
        guider = ClassifierFreeGuidance(guidance_scale=7.5)
        pipe.update_components(guider=guider)
        inputs = self.get_dummy_inputs()
        out_cfg = pipe(**inputs, output="images")
        out_cfg = pipe(**inputs, output=self.output_name)

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = torch.abs(out_cfg - out_no_cfg).max()
@@ -578,3 +598,68 @@ class TestModularModelCardContent:
        content = generate_modular_model_card_content(blocks)

        assert "5-block architecture" in content["model_description"]


class TestAutoModelLoadIdTagging:
    def test_automodel_tags_load_id(self):
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")

        assert hasattr(model, "_diffusers_load_id"), "Model should have _diffusers_load_id attribute"
        assert model._diffusers_load_id != "null", "_diffusers_load_id should not be 'null'"

        # Verify load_id contains the expected fields
        load_id = model._diffusers_load_id
        assert "hf-internal-testing/tiny-stable-diffusion-xl-pipe" in load_id
        assert "unet" in load_id

    def test_automodel_update_components(self):
        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)

        auto_model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")

        pipe.update_components(unet=auto_model)

        assert pipe.unet is auto_model

        assert "unet" in pipe._component_specs
        spec = pipe._component_specs["unet"]
        assert spec.pretrained_model_name_or_path == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
        assert spec.subfolder == "unet"


class TestLoadComponentsSkipBehavior:
    def test_load_components_skips_already_loaded(self):
        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)

        original_unet = pipe.unet

        pipe.load_components()

        # Verify that the unet is the same object (not reloaded)
        assert pipe.unet is original_unet, "load_components should skip already loaded components"

    def test_load_components_selective_loading(self):
        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")

        pipe.load_components(names="unet", torch_dtype=torch.float32)

        # Verify only requested component was loaded.
        assert hasattr(pipe, "unet")
        assert pipe.unet is not None
        assert getattr(pipe, "vae", None) is None

    def test_load_components_skips_invalid_pretrained_path(self):
        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")

        pipe._component_specs["test_component"] = ComponentSpec(
            name="test_component",
            type_hint=torch.nn.Module,
            pretrained_model_name_or_path=None,
            default_creation_method="from_pretrained",
        )
        pipe.load_components(torch_dtype=torch.float32)

        # Verify test_component was not loaded
        assert not hasattr(pipe, "test_component") or pipe.test_component is None
0
tests/modular_pipelines/wan/__init__.py
Normal file
49
tests/modular_pipelines/wan/test_modular_pipeline_wan.py
Normal file
@@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from diffusers.modular_pipelines import WanBlocks, WanModularPipeline

from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestWanModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = WanModularPipeline
    pipeline_blocks_class = WanBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe"

    params = frozenset(["prompt", "height", "width", "num_frames"])
    batch_params = frozenset(["prompt"])
    optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
    output_name = "videos"

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    @pytest.mark.skip(reason="num_videos_per_prompt")
    def test_num_images_per_prompt(self):
        pass
0
tests/modular_pipelines/z_image/__init__.py
Normal file
@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline

from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestZImageModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = ZImageModularPipeline
    pipeline_blocks_class = ZImageAutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe"

    params = frozenset(["prompt", "height", "width"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-3)
386
tests/pipelines/cosmos/test_cosmos2_5_transfer.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
Cosmos2_5_TransferPipeline,
|
||||
CosmosControlNetModel,
|
||||
CosmosTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from .cosmos_guardrail import DummyCosmosSafetyChecker
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Cosmos2_5_TransferWrapper(Cosmos2_5_TransferPipeline):
    @staticmethod
    def from_pretrained(*args, **kwargs):
        if "safety_checker" not in kwargs or kwargs["safety_checker"] is None:
            safety_checker = DummyCosmosSafetyChecker()
            device_map = kwargs.get("device_map", "cpu")
            torch_dtype = kwargs.get("torch_dtype")
            if device_map is not None or torch_dtype is not None:
                safety_checker = safety_checker.to(device_map, dtype=torch_dtype)
            kwargs["safety_checker"] = safety_checker
        return Cosmos2_5_TransferPipeline.from_pretrained(*args, **kwargs)


class Cosmos2_5_TransferPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Cosmos2_5_TransferWrapper
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        # Transformer with img_context support for Transfer2.5
        transformer = CosmosTransformer3DModel(
            in_channels=16 + 1,
            out_channels=16,
            num_attention_heads=2,
            attention_head_dim=16,
            num_layers=2,
            mlp_ratio=2,
            text_embed_dim=32,
            adaln_lora_dim=4,
            max_size=(4, 32, 32),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
            concat_padding_mask=True,
            extra_pos_embed_type="learnable",
            controlnet_block_every_n=1,
            img_context_dim_in=32,
            img_context_num_tokens=4,
            img_context_dim_out=32,
        )

        torch.manual_seed(0)
        controlnet = CosmosControlNetModel(
            n_controlnet_blocks=2,
            in_channels=16 + 1 + 1,  # control latent channels + condition_mask + padding_mask
            latent_channels=16 + 1 + 1,  # base latent channels (16) + condition_mask (1) + padding_mask (1) = 18
            model_channels=32,
            num_attention_heads=2,
            attention_head_dim=16,
            mlp_ratio=2,
            text_embed_dim=32,
            adaln_lora_dim=4,
            patch_size=(1, 2, 2),
            max_size=(4, 32, 32),
            rope_scale=(2.0, 1.0, 1.0),
            extra_pos_embed_type="learnable",  # Match transformer's config
            img_context_dim_in=32,
            img_context_dim_out=32,
            use_crossattn_projection=False,  # Test doesn't need this projection
        )

        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = UniPCMultistepScheduler()

        torch.manual_seed(0)
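        # Tiny Qwen2.5-VL config keeps the dummy text encoder small enough for fast CPU-only tests.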
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "controlnet": controlnet,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": DummyCosmosSafetyChecker(),
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            "num_frames": 3,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_components_function(self):
        init_components = self.get_dummy_components()
        init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
        pipe = self.pipeline_class(**init_components)
        self.assertTrue(hasattr(pipe, "components"))
        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]
        self.assertEqual(generated_video.shape, (3, 3, 32, 32))
        self.assertTrue(torch.isfinite(generated_video).all())

    def test_inference_with_controls(self):
        """Test inference with control inputs (ControlNet)."""
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        # Add control video input - should be a video tensor
        inputs["controls"] = [torch.randn(3, 3, 32, 32)]  # num_frames, channels, height, width
        inputs["controls_conditioning_scale"] = 1.0

        video = pipe(**inputs).frames
        generated_video = video[0]
        self.assertEqual(generated_video.shape, (3, 3, 32, 32))
        self.assertTrue(torch.isfinite(generated_video).all())

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            for tensor_name in callback_kwargs.keys():
                assert tensor_name in pipe._callback_tensor_inputs
            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs
            for tensor_name in callback_kwargs.keys():
                assert tensor_name in pipe._callback_tensor_inputs
            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        _ = pipe(**inputs)[0]

        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        _ = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not getattr(self, "test_attention_slicing", True):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

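    # Saving with a variant should write variant-suffixed weight files for every saved model component.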
    def test_serialization_with_variants(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        model_components = [
            component_name
            for component_name, component in pipe.components.items()
            if isinstance(component, torch.nn.Module)
        ]
        # Remove components that aren't saved as standard diffusers models
        if "safety_checker" in model_components:
            model_components.remove("safety_checker")
        variant = "fp16"

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

            with open(f"{tmpdir}/model_index.json", "r") as f:
                config = json.load(f)

            for subfolder in os.listdir(tmpdir):
                if not os.path.isfile(subfolder) and subfolder in model_components:
                    folder_path = os.path.join(tmpdir, subfolder)
                    is_folder = os.path.isdir(folder_path) and subfolder in config
                    assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))

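    # A per-component torch_dtype dict loads the named component in its own dtype and everything else in "default".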
    def test_torch_dtype_dict(self):
        components = self.get_dummy_components()
        if not components:
            self.skipTest("No dummy components defined.")

        pipe = self.pipeline_class(**components)

        specified_key = next(iter(components.keys()))

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
            pipe.save_pretrained(tmpdirname, safe_serialization=False)
            torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
            loaded_pipe = self.pipeline_class.from_pretrained(
                tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
            )

            for name, component in loaded_pipe.components.items():
                # Skip components that are not loaded from disk or have special handling
                if name == "safety_checker":
                    continue
                if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
                    expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
                    self.assertEqual(
                        component.dtype,
                        expected_dtype,
                        f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
                    )

    def test_save_load_optional_components(self, expected_max_difference=1e-4):
        self.pipeline_class._optional_components.remove("safety_checker")
        super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
        self.pipeline_class._optional_components.append("safety_checker")

    @unittest.skip(
        "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
        "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
        "too large and slow to run on CI."
    )
    def test_encode_prompt_works_in_isolation(self):
        pass
@@ -281,6 +281,86 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        # Should return 4 images (2 prompts × 2 images per prompt)
        self.assertEqual(len(images), 4)

    def test_prompt_with_prior_token_ids(self):
        """Test that prompt and prior_token_ids can be provided together.

        When both are given, the AR generation step is skipped (prior_token_ids is used
        directly) and prompt is used to generate prompt_embeds via the glyph encoder.
        """
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        height, width = 32, 32

        # Step 1: Run with prompt only to get prior_token_ids from AR model
        generator = torch.Generator(device=device).manual_seed(0)
        prior_token_ids, _, _ = pipe.generate_prior_tokens(
            prompt="A photo of a cat",
            height=height,
            width=width,
            device=torch.device(device),
            generator=torch.Generator(device=device).manual_seed(0),
        )

        # Step 2: Run with both prompt and prior_token_ids — should not raise
        generator = torch.Generator(device=device).manual_seed(0)
        inputs_both = {
            "prompt": "A photo of a cat",
            "prior_token_ids": prior_token_ids,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.5,
            "height": height,
            "width": width,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        images = pipe(**inputs_both).images
        self.assertEqual(len(images), 1)
        self.assertEqual(images[0].shape, (3, 32, 32))

    def test_check_inputs_rejects_invalid_combinations(self):
        """Test that check_inputs correctly rejects invalid input combinations."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)

        height, width = 32, 32

        # Neither prompt nor prior_token_ids → error
        with self.assertRaises(ValueError):
            pipe.check_inputs(
                prompt=None,
                height=height,
                width=width,
                callback_on_step_end_tensor_inputs=None,
                prompt_embeds=torch.randn(1, 16, 32),
            )

        # prior_token_ids alone without prompt or prompt_embeds → error
        with self.assertRaises(ValueError):
            pipe.check_inputs(
                prompt=None,
                height=height,
                width=width,
                callback_on_step_end_tensor_inputs=None,
                prior_token_ids=torch.randint(0, 100, (1, 64)),
            )

        # prompt + prompt_embeds together → error
        with self.assertRaises(ValueError):
            pipe.check_inputs(
                prompt="A cat",
                height=height,
                width=width,
                callback_on_step_end_tensor_inputs=None,
                prompt_embeds=torch.randn(1, 16, 32),
            )

    @unittest.skip("Needs to be revisited.")
    def test_encode_prompt_works_in_isolation(self):
        pass
@@ -2406,7 +2406,11 @@ class PipelineTesterMixin:
            if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
                # `component.device` prints the `onload_device` type. We should probably override the
                # `device` property in `ModelMixin`.
                component_device = next(component.parameters())[0].device
                # Skip modules with no parameters (e.g., dummy safety checkers with only buffers)
                params = list(component.parameters())
                if not params:
                    continue
                component_device = params[0].device
                self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)

    @require_torch_accelerator
@@ -168,7 +168,7 @@ def assert_tensors_close(
    max_diff = abs_diff.max().item()

    flat_idx = abs_diff.argmax().item()
    max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
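    # torch.unravel_index returns a tuple of index tensors, so convert each with .item() rather than
    # calling .tolist() on the tuple.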
    max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))

    threshold = atol + rtol * expected.abs()
    mismatched = (abs_diff > threshold).sum().item()