# diffusers/scripts/convert_hunyuan_video1_5_to_diffusers.py
import argparse
import json
import os
import pathlib
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import (
AutoModel,
AutoTokenizer,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
)
from diffusers import (
AutoencoderKLHunyuanVideo15,
ClassifierFreeGuidance,
FlowMatchEulerDiscreteScheduler,
HunyuanVideo15ImageToVideoPipeline,
HunyuanVideo15Pipeline,
HunyuanVideo15Transformer3DModel,
)
# To convert only the transformer:
"""
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
    --original_state_dict_repo_id tencent/HunyuanVideo-1.5 \
    --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer \
    --transformer_type 480p_t2v
"""
# To convert the full pipeline:
"""
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
    --original_state_dict_repo_id tencent/HunyuanVideo-1.5 \
    --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
    --save_pipeline \
    --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2 \
    --transformer_type 480p_t2v
"""
TRANSFORMER_CONFIGS = {
"480p_t2v": {
"target_size": 640,
"task_type": "i2v",
},
"720p_t2v": {
"target_size": 960,
"task_type": "t2v",
},
"720p_i2v": {
"target_size": 960,
"task_type": "i2v",
},
"480p_t2v_distilled": {
"target_size": 640,
"task_type": "t2v",
},
"480p_i2v_distilled": {
"target_size": 640,
"task_type": "i2v",
},
"720p_i2v_distilled": {
"target_size": 960,
"task_type": "i2v",
},
}
SCHEDULER_CONFIGS = {
"480p_t2v": {
"shift": 5.0,
},
"480p_i2v": {
"shift": 5.0,
},
"720p_t2v": {
"shift": 9.0,
},
"720p_i2v": {
"shift": 7.0,
},
"480p_t2v_distilled": {
"shift": 5.0,
},
"480p_i2v_distilled": {
"shift": 5.0,
},
"720p_i2v_distilled": {
"shift": 7.0,
},
}
GUIDANCE_CONFIGS = {
"480p_t2v": {
"guidance_scale": 6.0,
},
"480p_i2v": {
"guidance_scale": 6.0,
},
"720p_t2v": {
"guidance_scale": 6.0,
},
"720p_i2v": {
"guidance_scale": 6.0,
},
"480p_t2v_distilled": {
"guidance_scale": 1.0,
},
"480p_i2v_distilled": {
"guidance_scale": 1.0,
},
"720p_i2v_distilled": {
"guidance_scale": 1.0,
},
}
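# Note: guidance_scale=1.0 for the distilled checkpoints effectively disables classifier-free
# guidance, while the base checkpoints default to a CFG scale of 6.0.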
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
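# swap_scale_shift is needed because the original final_layer AdaLN weights are stored as
# (shift, scale) while the Diffusers norm_out layer expects (scale, shift). Illustrative only:
#
#   swap_scale_shift(torch.arange(4.0))  # [0, 1, 2, 3] -> tensor([2., 3., 0., 1.])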
def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
"""
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
"""
converted_state_dict = {}
# 1. time_embed.timestep_embedder <- time_in
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_in.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_in.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")
# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"txt_in.t_embedder.mlp.0.bias"
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"txt_in.t_embedder.mlp.2.bias"
)
# 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
"txt_in.c_embedder.linear_1.weight"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
"txt_in.c_embedder.linear_1.bias"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
"txt_in.c_embedder.linear_2.weight"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
"txt_in.c_embedder.linear_2.bias"
)
# 4. context_embedder.proj_in <- txt_in.input_embedder
converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
# 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
num_refiner_blocks = 2
for i in range(num_refiner_blocks):
block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}."
# norm1
converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight")
converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias")
# Split self_attn_qkv into to_q, to_k, to_v
qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight")
qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias")
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
# self_attn_proj -> attn.to_out.0
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"{orig_prefix}self_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"{orig_prefix}self_attn_proj.bias"
)
# norm2
converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight")
converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias")
# mlp -> ff
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias")
# adaLN_modulation -> norm_out
converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}adaLN_modulation.1.weight"
)
converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}adaLN_modulation.1.bias"
)
# 6. context_embedder_2 <- byt5_in
converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
# 7. image_embedder <- vision_in
converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight")
converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias")
converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight")
converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias")
converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight")
converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias")
converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight")
converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias")
# 8. x_embedder <- img_in
converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
# 9. cond_type_embed <- cond_type_embedding
converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight")
# 10. transformer_blocks <- double_blocks
num_layers = 54
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
orig_prefix = f"double_blocks.{i}."
# norm1 (img_mod)
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mod.linear.bias"
)
# norm1_context (txt_mod)
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mod.linear.bias"
)
# img attention (to_q, to_k, to_v)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_v.bias"
)
# img attention qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k_norm.weight"
)
# img attention output projection
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_proj.bias"
)
# txt attention (add_q_proj, add_k_proj, add_v_proj)
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_v.bias"
)
# txt attention qk norm
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k_norm.weight"
)
# txt attention output projection
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_proj.bias"
)
        # norm2 and norm2_context carry no weights (LayerNorm with elementwise_affine=False), so they are skipped
# img_mlp -> ff
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc2.bias"
)
# txt_mlp -> ff_context
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc2.bias"
)
# 11. norm_out and proj_out <- final_layer
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
)
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
return converted_state_dict
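# Every mapping above uses .pop(), so any key still left in `original_state_dict` after this
# function returns was not converted. Callers can use that as a sanity check (illustrative):
#
#   sd = load_original_transformer_state_dict(args)
#   converted = convert_hyvideo15_transformer_to_diffusers(sd)
#   assert not sd, f"unconverted keys: {sorted(sd)[:5]}"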
def convert_hunyuan_video_15_vae_checkpoint_to_diffusers(
original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
):
converted = {}
# 1. Encoder
# 1.1 conv_in
converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
# 1.2 Down blocks
for down_block_index in range(len(block_out_channels)): # 0 to 4
# ResNet blocks
for resnet_block_index in range(layers_per_block): # 0 to 1
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
original_state_dict.pop(
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
)
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
original_state_dict.pop(
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
)
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
)
# Downsample (if exists)
if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
)
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
)
# 1.3 Mid block
converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_1.conv1.conv.weight"
)
converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_1.conv1.conv.bias"
)
converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_1.conv2.conv.weight"
)
converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_1.conv2.conv.bias"
)
converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_2.conv1.conv.weight"
)
converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_2.conv1.conv.bias"
)
converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_2.conv2.conv.weight"
)
converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_2.conv2.conv.bias"
)
# Attention block
converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
"encoder.mid.attn_1.proj_out.weight"
)
converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
"encoder.mid.attn_1.proj_out.bias"
)
# 1.4 Encoder output
converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
# 2. Decoder
# 2.1 conv_in
converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
# 2.2 Mid block
converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_1.conv1.conv.weight"
)
converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_1.conv1.conv.bias"
)
converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_1.conv2.conv.weight"
)
converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_1.conv2.conv.bias"
)
converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_2.conv1.conv.weight"
)
converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_2.conv1.conv.bias"
)
converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_2.conv2.conv.weight"
)
converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_2.conv2.conv.bias"
)
# Decoder attention block
converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
"decoder.mid.attn_1.proj_out.weight"
)
converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
"decoder.mid.attn_1.proj_out.bias"
)
# 2.3 Up blocks
    for up_block_index in range(len(block_out_channels)):  # 0 to 4
# ResNet blocks
for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
)
# Upsample (if exists)
if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
)
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
)
# 2.4 Decoder output
converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
return converted
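# Illustrative examples of the key renames produced by the loops above (first level, first block;
# the downsampler entry only exists on levels that have one):
#
#   encoder.down.0.block.0.conv1.conv.weight -> encoder.down_blocks.0.resnets.0.conv1.conv.weight
#   decoder.up.0.block.0.norm1.gamma         -> decoder.up_blocks.0.resnets.0.norm1.gamma
#   encoder.down.0.downsample.conv.conv.bias -> encoder.down_blocks.0.downsamplers.0.conv.conv.bias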
def load_sharded_safetensors(directory: pathlib.Path):
    file_paths = list(directory.glob("diffusion_pytorch_model*.safetensors"))
state_dict = {}
for path in file_paths:
state_dict.update(load_file(path))
return state_dict
def load_original_transformer_state_dict(args):
if args.original_state_dict_repo_id is not None:
model_dir = snapshot_download(
args.original_state_dict_repo_id,
repo_type="model",
allow_patterns="transformer/" + args.transformer_type + "/*",
)
elif args.original_state_dict_folder is not None:
model_dir = pathlib.Path(args.original_state_dict_folder)
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
model_dir = pathlib.Path(model_dir)
model_dir = model_dir / "transformer" / args.transformer_type
return load_sharded_safetensors(model_dir)
def load_original_vae_state_dict(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(
repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors"
)
elif args.original_state_dict_folder is not None:
model_dir = pathlib.Path(args.original_state_dict_folder)
ckpt_path = model_dir / "vae/diffusion_pytorch_model.safetensors"
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
original_state_dict = load_file(ckpt_path)
return original_state_dict
def convert_transformer(args):
original_state_dict = load_original_transformer_state_dict(args)
config = TRANSFORMER_CONFIGS[args.transformer_type]
with init_empty_weights():
transformer = HunyuanVideo15Transformer3DModel(**config)
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
transformer.load_state_dict(state_dict, strict=True, assign=True)
return transformer
def convert_vae(args):
original_state_dict = load_original_vae_state_dict(args)
with init_empty_weights():
vae = AutoencoderKLHunyuanVideo15()
state_dict = convert_hunyuan_video_15_vae_checkpoint_to_diffusers(original_state_dict)
vae.load_state_dict(state_dict, strict=True, assign=True)
return vae
def load_mllm():
print(" loading from Qwen/Qwen2.5-VL-7B-Instruct")
text_encoder = AutoModel.from_pretrained(
"Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
if hasattr(text_encoder, "language_model"):
text_encoder = text_encoder.language_model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
return text_encoder, tokenizer
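# Note: Qwen/Qwen2.5-VL-7B-Instruct is a vision-language model; load_mllm() keeps only its
# language_model sub-module because the pipelines use it purely as a text encoder.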
# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
def add_special_token(
tokenizer,
text_encoder,
add_color=True,
add_font=True,
multilingual=True,
color_ann_path="assets/color_idx.json",
font_ann_path="assets/multilingual_10-lang_idx.json",
):
"""
Add special tokens for color and font to tokenizer and text encoder.
Args:
tokenizer: Huggingface tokenizer.
text_encoder: Huggingface T5 encoder.
add_color (bool): Whether to add color tokens.
add_font (bool): Whether to add font tokens.
color_ann_path (str): Path to color annotation JSON.
font_ann_path (str): Path to font annotation JSON.
multilingual (bool): Whether to use multilingual font tokens.
"""
with open(font_ann_path, "r") as f:
idx_font_dict = json.load(f)
with open(color_ann_path, "r") as f:
idx_color_dict = json.load(f)
if multilingual:
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
else:
font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
additional_special_tokens = []
if add_color:
additional_special_tokens += color_token
if add_font:
additional_special_tokens += font_token
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
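# With the Glyph-SDXL-v2 assets, the tokens added above take the form "<color-N>" (one per entry
# in color_idx.json) and "<xx-font-N>", where "xx" is the first two characters of the font code
# (or simply "<font-N>" when multilingual=False).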
def load_byt5(args):
"""
Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
"""
# 1. Load base tokenizer and encoder
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
# Load as T5EncoderModel
encoder = T5EncoderModel.from_pretrained("google/byt5-small")
byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
# 2. Add special tokens
add_special_token(
tokenizer=tokenizer,
text_encoder=encoder,
add_color=True,
add_font=True,
color_ann_path=color_ann_path,
font_ann_path=font_ann_path,
multilingual=True,
)
# 3. Load Glyph-SDXL-v2 checkpoint
print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
    # Remap keys to match T5EncoderModel: replace any 'module.text_tower.encoder.' prefix with
    # 'encoder.', otherwise just prepend 'encoder.'
cleaned_state_dict = {}
for key, value in state_dict.items():
if key.startswith("module.text_tower.encoder."):
new_key = "encoder." + key[len("module.text_tower.encoder.") :]
cleaned_state_dict[new_key] = value
else:
new_key = "encoder." + key
cleaned_state_dict[new_key] = value
# 4. Load weights
missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False)
if unexpected_keys:
raise ValueError(f"Unexpected keys: {unexpected_keys}")
if "shared.weight" in missing_keys:
print(" Missing shared.weight as expected")
missing_keys.remove("shared.weight")
if missing_keys:
raise ValueError(f"Missing keys: {missing_keys}")
return encoder, tokenizer
def load_siglip():
image_encoder = SiglipVisionModel.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
)
return image_encoder, feature_extractor
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
)
parser.add_argument(
"--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict"
)
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--byt5_path",
type=str,
default=None,
help=(
"path to the downloaded byt5 checkpoint & assets. "
"Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: "
"`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` "
"or manually download following the instructions on "
"https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. "
"The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, "
"like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt"
),
)
parser.add_argument("--save_pipeline", action="store_true")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
if args.save_pipeline and args.byt5_path is None:
raise ValueError("Please provide --byt5_path when saving pipeline")
    transformer = convert_transformer(args)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True)
else:
task_type = transformer.config.task_type
vae = convert_vae(args)
text_encoder, tokenizer = load_mllm()
text_encoder_2, tokenizer_2 = load_byt5(args)
flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
if task_type == "i2v":
image_encoder, feature_extractor = load_siglip()
pipeline = HunyuanVideo15ImageToVideoPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
guider=guider,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
elif task_type == "t2v":
pipeline = HunyuanVideo15Pipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
guider=guider,
scheduler=scheduler,
)
else:
raise ValueError(f"Task type {task_type} is not supported")
pipeline.save_pretrained(args.output_path, safe_serialization=True)
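# Example of running the converted pipeline (a sketch, not part of the conversion flow; the path
# and prompt are placeholders and the call follows the usual Diffusers video pipeline API):
"""
import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

pipe = HunyuanVideo15Pipeline.from_pretrained(
    "/fsx/yiyi/HunyuanVideo-1.5-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
output = pipe(prompt="a cat walking on grass")
export_to_video(output.frames[0], "output.mp4", fps=24)
"""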