[WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526)

---------

Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
This commit is contained in:
dg845
2025-11-12 18:52:31 -08:00
committed by GitHub
parent 44c3101685
commit d8e4805816
19 changed files with 3676 additions and 33 deletions

View File

@@ -387,6 +387,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_animate_transformer_3d
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers

View File

@@ -0,0 +1,30 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanAnimateTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import WanAnimateTransformer3DModel
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-720P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanAnimateTransformer3DModel
[[autodoc]] WanAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
@@ -249,6 +250,220 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
The project page: https://humanaigc.github.io/wan-animate
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
#### Usage
The Wan-Animate pipeline supports two modes of operation:
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
##### Prerequisites
Before using the pipeline, you need to preprocess your reference video to extract:
- **Pose video**: Contains skeletal keypoints representing body motion
- **Face video**: Contains facial feature representations for expression control
For replacement mode, you additionally need:
- **Background video**: The original video containing the scene
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
> [!NOTE]
> The preprocessing tools are available in the original Wan-Animate repository. Integration of these preprocessing steps into Diffusers is planned for a future release.
The example below demonstrates how to use the Wan-Animate pipeline:
<hfoptions id="Animate usage">
<hfoption id="Animation mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Load character image and preprocessed videos
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
# Resize image to match VAE constraints
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
# Generate animated video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
guidance_scale=5.0,
mode="animation", # Animation mode (default)
).frames[0]
export_to_video(output, "animated_character.mp4", fps=16)
```
</hfoption>
<hfoption id="Replacement mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Load all required inputs for replacement mode
image = load_image("path/to/new_character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
background_video = load_video("path/to/background_video.mp4") # Original scene
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
# Resize image to match video dimensions
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
# Replace character in background video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
guidance_scale=5.0,
mode="replacement", # Replacement mode
).frames[0]
export_to_video(output, "character_replaced.mp4", fps=16)
```
</hfoption>
<hfoption id="Advanced options">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4")
face_video = load_video("path/to/face_video.mp4")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio"
negative_prompt = "blurry, low quality"
# Advanced: Use temporal guidance and custom callback
def callback_fn(pipe, step_index, timestep, callback_kwargs):
# You can modify latents or other tensors here
print(f"Step {step_index}, Timestep {timestep}")
return callback_kwargs
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
num_inference_steps=50,
guidance_scale=5.0,
num_frames_for_temporal_guidance=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
callback_on_step_end=callback_fn,
callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
export_to_video(output, "animated_advanced.mp4", fps=16)
```
</hfoption>
</hfoptions>
#### Key Parameters
- **mode**: Choose between `"animation"` (default) or `"replacement"`
- **num_frames_for_temporal_guidance**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt
- **num_frames**: Total number of frames to generate. Should be divisible by `vae_scale_factor_temporal` (default: 4)
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -281,10 +496,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
@@ -359,6 +574,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
## WanAnimatePipeline
[[autodoc]] WanAnimatePipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput

View File

@@ -6,11 +6,20 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
CLIPVisionModelWithProjection,
UMT5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"after_proj": "proj_out",
}
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"cross_attn.k_img": "attn2.to_k_img",
"cross_attn.v_img": "attn2.to_v_img",
"cross_attn.norm_k_img": "attn2.norm_k_img",
# After cross_attn -> attn2 rename, we need to rename the img keys
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
# Motion encoder mappings
# The name mapping is complicated for the convolutional part so we handle that in its own function
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
"face_encoder.conv2.conv": "face_encoder.conv2",
"face_encoder.conv3.conv": "face_encoder.conv3",
# Face adapter mappings are handled in a separate function
}
# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
"""
Convert all motion encoder weights for Animate model.
In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
Conversion strategy:
1. Drop .kernel buffers (blur kernels)
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
"""
# Skip if not a weight, bias, or kernel
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
return
# Handle Blur kernel buffers from original implementation.
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
if ".kernel" in key and "motion_encoder" in key:
# Remove unexpected blur kernel buffers to avoid strict load errors
state_dict.pop(key, None)
return
# Rename Sequential indices to named components in ConvLayer and ResBlock
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
parts = key.split(".")
# Find the sequential index (digit) after convs or after conv1/conv2/skip
# Examples:
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
# - enc.net_app.convs.8 -> conv_out (final conv layer)
convs_idx = parts.index("convs") if "convs" in parts else -1
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
bias = False
# The nn.Sequential index will always follow convs
sequential_idx = int(parts[convs_idx + 1])
if sequential_idx == 0:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_in.weight"
elif key.endswith(".bias"):
new_key = "motion_encoder.conv_in.act_fn.bias"
bias = True
elif sequential_idx == final_conv_idx:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_out.weight"
else:
# Intermediate .convs. layers, which get mapped to .res_blocks.
prefix = "motion_encoder.res_blocks."
layer_name = parts[convs_idx + 2]
if layer_name == "skip":
layer_name = "conv_skip"
if key.endswith(".weight"):
param_name = "weight"
elif key.endswith(".bias"):
param_name = "act_fn.bias"
bias = True
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
param = state_dict.pop(key)
if bias:
param = param.squeeze()
state_dict[new_key] = param
return
return
return
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert face adapter weights for the Animate model.
The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
"""
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
prefix = "face_adapter."
if ".fuser_blocks." in key:
parts = key.split(".")
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
block_idx = parts[module_list_idx + 1]
layer_name = parts[module_list_idx + 2]
param_name = parts[module_list_idx + 3]
if layer_name == "linear1_kv":
layer_name_k = "to_k"
layer_name_v = "to_v"
suffix_k = ".".join([block_idx, layer_name_k, param_name])
suffix_v = ".".join([block_idx, layer_name_v, param_name])
new_key_k = prefix + suffix_k
new_key_v = prefix + suffix_v
kv_proj = state_dict.pop(key)
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
state_dict[new_key_k] = k_proj
state_dict[new_key_v] = v_proj
return
else:
if layer_name == "q_norm":
new_layer_name = "norm_q"
elif layer_name == "k_norm":
new_layer_name = "norm_k"
elif layer_name == "linear1_q":
new_layer_name = "to_q"
elif layer_name == "linear2":
new_layer_name = "to_out"
suffix_parts = [block_idx, new_layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
state_dict[new_key] = state_dict.pop(key)
return
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"motion_encoder": convert_animate_motion_encoder_weights,
"face_adapter": convert_animate_face_adapter_weights,
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": None,
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
"motion_style_dim": 512,
"motion_dim": 20,
"motion_encoder_dim": 512,
"face_encoder_hidden_dim": 1024,
"face_encoder_num_heads": 4,
"inject_face_latents_blocks": 5,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
# Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
# Move to CPU to ensure all tensors are materialized
transformer = transformer.to("cpu")
return transformer
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,

View File

@@ -268,6 +268,7 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
@@ -636,6 +637,7 @@ else:
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
@@ -977,6 +979,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
@@ -1315,6 +1318,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,

View File

@@ -409,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
@@ -460,7 +460,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio > src_ratio else image.width * height // image.height
src_h = height if ratio <= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
return res

View File

@@ -108,6 +108,7 @@ if is_torch_available():
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -214,6 +215,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)

View File

@@ -42,4 +42,5 @@ if is_torch_available():
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel

View File

@@ -188,6 +188,11 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
@@ -213,11 +218,7 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

File diff suppressed because it is too large Load Diff

View File

@@ -385,7 +385,13 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["wan"] = [
"WanPipeline",
"WanImageToVideoPipeline",
"WanVideoToVideoPipeline",
"WanVACEPipeline",
"WanAnimatePipeline",
]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
@@ -803,7 +809,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
from .wan import (
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
)
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
_import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
@@ -35,10 +36,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_wan import WanPipeline
from .pipeline_wan_animate import WanAnimatePipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .pipeline_wan_vace import WanVACEPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
else:
import sys

View File

@@ -0,0 +1,185 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from ...configuration_utils import register_to_config
from ...image_processor import VaeImageProcessor
from ...utils import PIL_INTERPOLATION
class WanAnimateImageProcessor(VaeImageProcessor):
r"""
Image processor to preprocess the reference (character) image for the Wan Animate model.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
this factor.
vae_latent_channels (`int`, *optional*, defaults to `16`):
VAE latent channels.
spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`):
The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2).
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to grayscale format.
fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `None`):
An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that
color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid;
if `None`, will default to filling with data from `image`.
"""
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
vae_latent_channels: int = 16,
spatial_patch_size: Tuple[int, int] = (2, 2),
resample: str = "lanczos",
reducing_gap: int = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0,
):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
def _resize_and_fill(
self,
image: PIL.Image.Image,
width: int,
height: int,
) -> PIL.Image.Image:
r"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling empty with data from image.
Args:
image (`PIL.Image.Image`):
The image to resize and fill.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and filled image.
"""
ratio = width / height
src_ratio = image.width / image.height
fill_with_image_data = self.config.fill_color is None
fill_color = self.config.fill_color or 0
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = PIL.Image.new("RGB", (width, height), color=fill_color)
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if fill_with_image_data:
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
if fill_height > 0:
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
if fill_width > 0:
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
box=(fill_width + src_w, 0),
)
return res
def get_default_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor * spatial_patch_size`.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
max_area = width * height
aspect_ratio = height / width
mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0]
mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1]
# Try to preserve the aspect ratio
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w
return height, width

File diff suppressed because it is too large Load Diff

View File

@@ -1623,6 +1623,21 @@ class VQModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class WanAnimateTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class WanTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -3512,6 +3512,21 @@ class VQDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class WanAnimatePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,126 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import WanAnimateTransformer3DModel
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
@property
def input_shape(self):
return (12, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12
"latent_channels": 4,
"out_channels": 4,
"text_dim": 16,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 2,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"image_dim": 16,
"rope_max_seq_len": 32,
"motion_encoder_channel_sizes": channel_sizes, # Start of Wan Animate-specific config
"motion_encoder_size": 16, # Ensures that there will be 2 motion encoder resblocks
"motion_style_dim": 8,
"motion_dim": 4,
"motion_encoder_dim": 16,
"face_encoder_hidden_dim": 16,
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Override test_output because the transformer output is expected to have less channels than the main transformer
# input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
T5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
)
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanAnimatePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
channel_sizes = {"4": 16, "8": 16, "16": 16}
transformer = WanAnimateTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
latent_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
image_dim=4,
rope_max_seq_len=32,
motion_encoder_channel_sizes=channel_sizes,
motion_encoder_size=16,
motion_style_dim=8,
motion_dim=4,
motion_encoder_dim=16,
face_encoder_hidden_dim=16,
face_encoder_num_heads=2,
inject_face_latents_blocks=2,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=2,
image_size=4,
intermediate_size=16,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
torch.manual_seed(0)
image_processor = CLIPImageProcessor(crop_size=4, size=4)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
num_frames = 17
height = 16
width = 16
face_height = 16
face_width = 16
image = Image.new("RGB", (height, width))
pose_video = [Image.new("RGB", (height, width))] * num_frames
face_video = [Image.new("RGB", (face_height, face_width))] * num_frames
inputs = {
"image": image,
"pose_video": pose_video,
"face_video": face_video,
"prompt": "dance monkey",
"negative_prompt": "negative",
"height": height,
"width": width,
"segment_frame_length": 77, # TODO: can we set this to num_frames?
"num_inference_steps": 2,
"mode": "animate",
"prev_segment_conditioning_frames": 1,
"generator": generator,
"guidance_scale": 1.0,
"output_type": "pt",
"max_sequence_length": 16,
}
return inputs
def test_inference(self):
"""Test basic inference in animation mode."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
expected_video = torch.randn(17, 3, 16, 16)
max_diff = np.abs(video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_inference_replacement(self):
"""Test the pipeline in replacement mode with background and mask videos."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["mode"] = "replace"
num_frames = 17
height = 16
width = 16
inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
inputs["mask_video"] = [Image.new("L", (height, width))] * num_frames
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip(
"Setting the Wan Animate latents to zero at the last denoising step does not guarantee that the output will be"
" zero. I believe this is because the latents are further processed in the outer loop where we loop over"
" inference segments."
)
def test_callback_inputs(self):
pass
@slow
@require_torch_accelerator
class WanAnimatePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_wan_animate(self):
pass

View File

@@ -16,6 +16,7 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
@@ -721,6 +722,33 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
}
class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf"
torch_dtype = torch.bfloat16
model_cls = WanAnimateTransformer3DModel
expected_memory_use_in_gb = 9
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"control_hidden_states": torch.randn(
(1, 96, 2, 64, 64),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"control_hidden_states_scale": torch.randn(
(8,),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
torch_dtype = torch.bfloat16