Compare commits

...

8 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Dhruv Nair | c4262eea9b | update | 2025-03-07 17:27:56 +01:00 |
| Dhruv Nair | 36f7eaa994 | update | 2025-03-07 12:56:55 +01:00 |
| Dhruv Nair | 7464205c47 | Merge branch 'main' into wan-sf | 2025-03-07 11:17:14 +01:00 |
| Dhruv Nair | 634aa8d4a7 | update | 2025-03-07 11:17:00 +01:00 |
| DN6 | 6399810347 | update | 2025-03-06 21:40:00 +05:30 |
| DN6 | b439e5db6d | update | 2025-03-06 21:38:56 +05:30 |
| DN6 | adde0818fd | update | 2025-03-06 20:15:23 +05:30 |
| DN6 | 5d45b0ce9d | update | 2025-03-06 20:08:50 +05:30 |
8 changed files with 519 additions and 51 deletions

View File

@@ -45,6 +45,22 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
### Using single file loading with Wan
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` method.
```python
import torch
from diffusers import WanPipeline, WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
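
The newly added `AutoencoderKLWan.from_single_file` works the same way. A sketch combining both single-file components into one pipeline (the VAE checkpoint URL below is the one exercised by this PR's tests):

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel

# Checkpoint URLs taken from the docs example and the tests in this PR.
transformer = WanTransformer3DModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
    torch_dtype=torch.bfloat16,
)
vae = AutoencoderKLWan.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer, vae=vae
)
```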
## WanPipeline
[[autodoc]] WanPipeline

View File

@@ -39,6 +39,8 @@ from .single_file_utils import (
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,
    convert_wan_transformer_to_diffusers,
    convert_wan_vae_to_diffusers,
    create_controlnet_diffusers_config_from_ldm,
    create_unet_diffusers_config_from_ldm,
    create_vae_diffusers_config_from_ldm,
@@ -117,6 +119,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
"WanTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
}
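
For orientation, each entry in this registry pairs a model class with the function that remaps its original checkpoint keys, plus (presumably) the subfolder of the default Diffusers repo from which the model config is fetched. A simplified, hypothetical sketch of how `from_single_file` could consume an entry:

```python
def convert_for_class(class_name: str, checkpoint: dict) -> dict:
    # Hypothetical, simplified dispatch; the real loader also resolves
    # configs, dtypes, and quantization options.
    entry = SINGLE_FILE_LOADABLE_CLASSES[class_name]
    return entry["checkpoint_mapping_fn"](checkpoint)
```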

View File

@@ -117,6 +117,8 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -176,6 +178,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
}
# Use to configure model sample size when original config is provided
@@ -664,6 +669,21 @@ def infer_diffusers_model_type(checkpoint):
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
        model_type = "lumina2"
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
            target_key = "model.diffusion_model.patch_embedding.weight"
        else:
            target_key = "patch_embedding.weight"

        if checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"
    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
        model_type = "wan-t2v-14B"
    else:
        model_type = "v1"
@@ -2470,7 +2490,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
-    new_state_dict = {}
+    converted_state_dict = {}

    # Comfy checkpoints add this prefix
    keys = list(checkpoint.keys())
@@ -2479,22 +2499,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
        checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Convert patch_embed
-    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
-    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

    # Convert time_embed
-    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
-    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
-    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
-    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
-    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
-    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
-    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
-    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
-    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
-    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")

    # Convert transformer blocks
    num_layers = 48
@@ -2503,68 +2523,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
        old_prefix = f"blocks.{i}."

        # norm1
-        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
-        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
        if i < num_layers - 1:
-            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
-            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
-        else:
-            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
                old_prefix + "mod_y.weight"
            )
-            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+            converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )
+        else:
+            converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
+                old_prefix + "mod_y.bias"
+            )

        # Visual attention
        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+        converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
+            old_prefix + "attn.proj_x.weight"
+        )
+        converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")

        # Context attention
        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
        q, k, v = qkv_weight.chunk(3, dim=0)

-        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
-        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
-        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
            old_prefix + "attn.q_norm_y.weight"
        )
-        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
            old_prefix + "attn.k_norm_y.weight"
        )
        if i < num_layers - 1:
-            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
                old_prefix + "attn.proj_y.weight"
            )
-            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+            converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.bias"
+            )

        # MLP
-        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+        converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
        )
-        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
        if i < num_layers - 1:
-            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+            converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
            )
-            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+            converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
+                old_prefix + "mlp_y.w2.weight"
+            )

    # Output layers
-    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
-    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
-    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
-    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

-    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+    converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

-    return new_state_dict
+    return converted_state_dict
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
@@ -2859,3 +2895,252 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
        converted_state_dict[diffusers_key] = checkpoint.pop(key)

    return converted_state_dict
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    TRANSFORMER_KEYS_RENAME_DICT = {
        "time_embedding.0": "condition_embedder.time_embedder.linear_1",
        "time_embedding.2": "condition_embedder.time_embedder.linear_2",
        "text_embedding.0": "condition_embedder.text_embedder.linear_1",
        "text_embedding.2": "condition_embedder.text_embedder.linear_2",
        "time_projection.1": "condition_embedder.time_proj",
        "cross_attn": "attn2",
        "self_attn": "attn1",
        ".o.": ".to_out.0.",
        ".q.": ".to_q.",
        ".k.": ".to_k.",
        ".v.": ".to_v.",
        ".k_img.": ".add_k_proj.",
        ".v_img.": ".add_v_proj.",
        ".norm_k_img.": ".norm_added_k.",
        "head.modulation": "scale_shift_table",
        "head.head": "proj_out",
        "modulation": "scale_shift_table",
        "ffn.0": "ffn.net.0.proj",
        "ffn.2": "ffn.net.2",
        # Hack to swap the layer names
        # The original model calls the norms in following order: norm1, norm3, norm2
        # We convert it to: norm1, norm2, norm3
        "norm2": "norm__placeholder",
        "norm3": "norm2",
        "norm__placeholder": "norm3",
        # For the I2V model
        "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
        "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
        "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
        "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    }

    for key in list(checkpoint.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = checkpoint.pop(key)

    return converted_state_dict
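
One subtlety in the rename table above: the renames are applied sequentially with `str.replace`, so the `norm__placeholder` entry performs a three-step swap of `norm2` and `norm3` (the original model applies its norms in the order norm1, norm3, norm2). A standalone sketch of the trick, using a hypothetical key for illustration:

```python
# Standalone demo of the placeholder swap used in TRANSFORMER_KEYS_RENAME_DICT.
swap = {"norm2": "norm__placeholder", "norm3": "norm2", "norm__placeholder": "norm3"}

key = "blocks.0.norm2.weight"
for old, new in swap.items():
    key = key.replace(old, new)

print(key)  # blocks.0.norm3.weight
```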
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in checkpoint.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            converted_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            converted_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            converted_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            converted_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            converted_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            converted_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            converted_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            converted_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            converted_state_dict[new_key] = value
        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    converted_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                converted_state_dict[new_key] = value
            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                converted_state_dict[new_key] = value
            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                converted_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                converted_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            converted_state_dict[key] = value

    return converted_state_dict
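
As a sanity check on the regrouping above: residual blocks at original indices 0-2, 4-6, 8-10, and 12-14 map to `up_blocks` 0-3, while indices 3, 7, and 11 become the upsamplers. For the residual indices, the explicit branches are equivalent to `divmod(block_idx, 4)`; a small sketch to verify:

```python
# Verifies the decoder.upsamples -> decoder.up_blocks regrouping used above.
for block_idx in [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]:
    new_block_idx, resnet_idx = divmod(block_idx, 4)
    print(f"decoder.upsamples.{block_idx} -> decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}")
```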

View File

@@ -284,8 +284,9 @@ class Attention(nn.Module):
            self.norm_added_q = RMSNorm(dim_head, eps=eps)
            self.norm_added_k = RMSNorm(dim_head, eps=eps)
        elif qk_norm == "rms_norm_across_heads":
-            # Wanx applies qk norm across all heads
-            self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
+            # Wan applies qk norm across all heads
+            # Wan also doesn't apply a q norm
+            self.norm_added_q = None
            self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
        else:
            raise ValueError(

View File

@@ -20,6 +20,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
@@ -655,7 +656,7 @@ class WanDecoder3d(nn.Module):
        return x


-class AutoencoderKLWan(ModelMixin, ConfigMixin):
+class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [Wan 2.1].

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention

@@ -288,7 +288,7 @@ class WanTransformerBlock(nn.Module):
        return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A Transformer model for video-like data used in the Wan model.

@@ -329,6 +329,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(

View File

@@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

from diffusers import (
    AutoencoderKLWan,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)

enable_full_determinism()


@require_torch_accelerator
class AutoencoderKLWanSingleFileTests(unittest.TestCase):
    model_class = AutoencoderKLWan
    ckpt_path = (
        "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
    )
    repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"

View File

@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import (
    WanTransformer3DModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_big_gpu_with_torch_cuda,
    require_torch_accelerator,
    torch_device,
)

enable_full_determinism()


@require_torch_accelerator
class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
    model_class = WanTransformer3DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
    repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"


@require_big_gpu_with_torch_cuda
@require_torch_accelerator
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
    model_class = WanTransformer3DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
    repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    torch_dtype = torch.float8_e4m3fn

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"