import argparse
import logging
import torch
from safetensors import safe_open
from diffusers import AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
"""
Usage examples
==============
python scripts/convert_hunyuan_image_to_diffusers.py \
--model_type hunyuanimage2.1 \
--transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1.safetensors" \
--vae_checkpoint_path "HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
--output_path "/raid/yiyi/test-hy21-diffusers" \
--dtype fp32
python scripts/convert_hunyuan_image_to_diffusers.py \
--model_type hunyuanimage2.1-distilled \
--transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1-distilled.safetensors" \
--vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
--output_path "/raid/yiyi/test-hy21-distilled-diffusers" \
--dtype fp32
python scripts/convert_hunyuan_image_to_diffusers.py \
--model_type hunyuanimage-refiner \
--transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage-refiner.safetensors" \
--vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_refiner/pytorch_model.pt" \
--output_path "/raid/yiyi/test-hy2-refiner-diffusers" \
--dtype fp32
"""
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_type",
    type=str,
    default=None,
    choices=["hunyuanimage2.1", "hunyuanimage2.1-distilled", "hunyuanimage-refiner"],
)
parser.add_argument("--transformer_checkpoint_path", default=None, type=str) # ckpts/dit/hunyuanimage2.1.safetensors
parser.add_argument("--vae_checkpoint_path", default=None, type=str) # ckpts/vae/vae_2_1/pytorch_model.ckpt
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="fp32")
args = parser.parse_args()
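# Note: only "bf16" selects bfloat16; any other value (including the "fp32" used in
# the usage examples above) falls back to float32.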
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32


# copied from https://github.com/Tencent-Hunyuan/HunyuanImage-2.1/hyimage/models/hunyuan/modules/hunyuanimage_dit.py#L21
def convert_hunyuan_dict_for_tensor_parallel(state_dict):
"""
Convert a Hunyuan model state dict to be compatible with tensor parallel architectures.
Args:
state_dict: Original state dict
Returns:
new_dict: Converted state dict
"""
new_dict = {}
for k, w in state_dict.items():
if k.startswith("double_blocks") and "attn_qkv.weight" in k:
hidden_size = w.shape[1]
k1 = k.replace("attn_qkv.weight", "attn_q.weight")
w1 = w[:hidden_size, :]
new_dict[k1] = w1
k2 = k.replace("attn_qkv.weight", "attn_k.weight")
w2 = w[hidden_size : 2 * hidden_size, :]
new_dict[k2] = w2
k3 = k.replace("attn_qkv.weight", "attn_v.weight")
w3 = w[-hidden_size:, :]
new_dict[k3] = w3
elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
hidden_size = w.shape[0] // 3
k1 = k.replace("attn_qkv.bias", "attn_q.bias")
w1 = w[:hidden_size]
new_dict[k1] = w1
k2 = k.replace("attn_qkv.bias", "attn_k.bias")
w2 = w[hidden_size : 2 * hidden_size]
new_dict[k2] = w2
k3 = k.replace("attn_qkv.bias", "attn_v.bias")
w3 = w[-hidden_size:]
new_dict[k3] = w3
elif k.startswith("single_blocks") and "linear1" in k:
hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
k1 = k.replace("linear1", "linear1_q")
w1 = w[:hidden_size]
new_dict[k1] = w1
k2 = k.replace("linear1", "linear1_k")
w2 = w[hidden_size : 2 * hidden_size]
new_dict[k2] = w2
k3 = k.replace("linear1", "linear1_v")
w3 = w[2 * hidden_size : 3 * hidden_size]
new_dict[k3] = w3
k4 = k.replace("linear1", "linear1_mlp")
w4 = w[3 * hidden_size :]
new_dict[k4] = w4
elif k.startswith("single_blocks") and "linear2" in k:
k1 = k.replace("linear2", "linear2.fc")
new_dict[k1] = w
else:
new_dict[k] = w
return new_dict
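
# Minimal sanity check of the splitting above (illustrative shapes only; a real
# checkpoint uses the model's actual hidden size, 28 heads x 128 dims = 3584):
#
#   sd = {"double_blocks.0.img_attn_qkv.weight": torch.randn(3 * 3584, 3584)}
#   out = convert_hunyuan_dict_for_tensor_parallel(sd)
#   assert out["double_blocks.0.img_attn_q.weight"].shape == (3584, 3584)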


def load_original_vae_checkpoint(args):
    # "ckpts/vae/vae_2_1/pytorch_model.ckpt"
    state_dict = torch.load(args.vae_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    vae_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("vae."):
            vae_state_dict[k.replace("vae.", "")] = v
    # Conv weights saved with a singleton third (depth) dimension are fake-3D
    # kernels; squeeze them back to 2D conv shape.
    for k, v in vae_state_dict.items():
        if "weight" in k and v.ndim == 5 and v.shape[2] == 1:
            vae_state_dict[k] = v.squeeze(2)
    return vae_state_dict


def load_original_refiner_vae_checkpoint(args):
    # "ckpts/vae/vae_refiner/pytorch_model.pt"
    state_dict = torch.load(args.vae_checkpoint_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
vae_state_dict = {}
for k, v in state_dict.items():
if k.startswith("vae."):
vae_state_dict[k.replace("vae.", "")] = v
return vae_state_dict


def load_original_transformer_checkpoint(args):
    # ckpts/dit/hunyuanimage2.1.safetensors
    # ckpts/dit/hunyuanimage-refiner.safetensors
    state_dict = {}
    with safe_open(args.transformer_checkpoint_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    if args.model_type.startswith("hunyuanimage2.1"):
        # The 2.1 checkpoints (base and distilled) store fused qkv/linear1 projections
        # that must be split into the per-projection keys the conversion below expects.
        # (The original check compared against "hunyuanimage-2.1", which matches no
        # documented model type, so the split never ran.)
        state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
    return state_dict


def convert_hunyuan_image_transformer_checkpoint_to_diffusers(
original_state_dict, use_byt5=True, guidance_distilled=False, use_meanflow=False
):
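    """
    Remap an original HunyuanImage DiT state dict (after qkv splitting) to the
    diffusers HunyuanImageTransformer2DModel layout.

    Args:
        original_state_dict: Original state dict; converted keys are popped from it.
        use_byt5: Whether the checkpoint has a byT5 glyph branch (`byt5_in.*`).
        guidance_distilled: Whether the checkpoint has a guidance embedder (`guidance_in.*`).
        use_meanflow: Whether the checkpoint has a second timestep embedder (`time_r_in.*`).

    Returns:
        (converted_state_dict, leftover_original_state_dict)
    """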
converted_state_dict = {}
# 1. byt5_in -> context_embedder_2
if use_byt5:
converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
# 2. img_in -> x_embedder
converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
# 3. txt_in -> context_embedder (complex mapping)
# txt_in.input_embedder -> context_embedder.proj_in
converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
# txt_in.t_embedder -> context_embedder.time_text_embed.timestep_embedder
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"txt_in.t_embedder.mlp.0.bias"
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"txt_in.t_embedder.mlp.2.bias"
)
# txt_in.c_embedder -> context_embedder.time_text_embed.text_embedder
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
"txt_in.c_embedder.linear_1.weight"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
"txt_in.c_embedder.linear_1.bias"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
"txt_in.c_embedder.linear_2.weight"
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
"txt_in.c_embedder.linear_2.bias"
)
# txt_in.individual_token_refiner -> context_embedder.token_refiner
for i in range(2): # 2 refiner blocks
block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
# norm1
converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.norm1.weight"
)
converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.norm1.bias"
)
# norm2
converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.norm2.weight"
)
converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.norm2.bias"
)
# Split QKV
qkv_weight = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.weight")
qkv_bias = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.bias")
q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q_weight
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k_weight
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v_weight
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
# attn projection
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.bias"
)
# MLP
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.bias"
)
# norm_out
converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.weight"
)
converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.bias"
)
    # 4. time_in -> time_guidance_embed.timestep_embedder
converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_in.mlp.0.weight"
)
converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_in.mlp.0.bias"
)
converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_in.mlp.2.weight"
)
converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_in.mlp.2.bias"
)
    # time_r_in -> time_guidance_embed.timestep_embedder_r
if use_meanflow:
converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
"time_r_in.mlp.0.weight"
)
converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
"time_r_in.mlp.0.bias"
)
converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
"time_r_in.mlp.2.weight"
)
converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
"time_r_in.mlp.2.bias"
)
# guidance_in -> time_guidance_embed.guidance_embedder
if guidance_distilled:
converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
"guidance_in.mlp.0.weight"
)
converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
"guidance_in.mlp.0.bias"
)
converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
"guidance_in.mlp.2.weight"
)
converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
"guidance_in.mlp.2.bias"
)
# 5. double_blocks -> transformer_blocks
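    # Block counts are fixed by the HunyuanImage-2.1 architecture; they must match
    # num_layers (20) and num_single_layers (40) in the configs built in main().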
for i in range(20): # 20 double blocks
block_prefix = f"transformer_blocks.{i}."
# norm1 (img_mod)
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.linear.bias"
)
# norm1_context (txt_mod)
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.linear.bias"
)
# img attention
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_v.bias"
)
# img attention norms
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_k_norm.weight"
)
# img attention projection
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn_proj.bias"
)
# img MLP
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.fc2.bias"
)
# txt attention (additional projections)
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_v.bias"
)
# txt attention norms
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_k_norm.weight"
)
# txt attention projection
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn_proj.bias"
)
# txt MLP (ff_context)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.fc2.bias"
)
# 6. single_blocks -> single_transformer_blocks
for i in range(40): # 40 single blocks
block_prefix = f"single_transformer_blocks.{i}."
# norm
converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.linear.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.linear.bias"
)
# attention Q, K, V
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_q.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_q.bias"
)
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_k.weight"
)
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_k.bias"
)
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_v.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_v.bias"
)
# attention norms
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"single_blocks.{i}.q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"single_blocks.{i}.k_norm.weight"
)
# MLP projection
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_mlp.weight"
)
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear1_mlp.bias"
)
# output projection
converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.fc.weight"
)
converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.fc.bias"
)
# 7. final_layer -> norm_out + proj_out
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
shift_w, scale_w = original_state_dict.pop("final_layer.adaLN_modulation.1.weight").chunk(2, dim=0)
shift_b, scale_b = original_state_dict.pop("final_layer.adaLN_modulation.1.bias").chunk(2, dim=0)
converted_state_dict["norm_out.linear.weight"] = torch.cat([scale_w, shift_w], dim=0)
converted_state_dict["norm_out.linear.bias"] = torch.cat([scale_b, shift_b], dim=0)
return converted_state_dict, original_state_dict


def convert_hunyuan_image_vae_checkpoint_to_diffusers(
original_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
):
"""Convert original VAE state dict to Diffusers format."""
converted = {}
# 1. Encoder
# 1.1 conv_in
converted["encoder.conv_in.weight"] = original_state_dict.pop("encoder.conv_in.weight")
converted["encoder.conv_in.bias"] = original_state_dict.pop("encoder.conv_in.bias")
# 1.2 down blocks
diffusers_block_idx = 0
for block_index in range(len(block_out_channels)):
for resnet_block_index in range(layers_per_block):
orig_prefix = f"encoder.down.{block_index}.block.{resnet_block_index}"
diff_prefix = f"encoder.down_blocks.{diffusers_block_idx}"
# resnet blocks
converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
diffusers_block_idx += 1
# downsample blocks
if f"encoder.down.{block_index}.downsample.conv.weight" in original_state_dict:
converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
f"encoder.down.{block_index}.downsample.conv.weight"
)
converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
f"encoder.down.{block_index}.downsample.conv.bias"
)
diffusers_block_idx += 1
# 1.3 mid block
converted["encoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("encoder.mid.block_1.norm1.weight")
converted["encoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("encoder.mid.block_1.norm1.bias")
converted["encoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("encoder.mid.block_1.conv1.weight")
converted["encoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("encoder.mid.block_1.conv1.bias")
converted["encoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("encoder.mid.block_1.norm2.weight")
converted["encoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("encoder.mid.block_1.norm2.bias")
converted["encoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("encoder.mid.block_1.conv2.weight")
converted["encoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("encoder.mid.block_1.conv2.bias")
converted["encoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("encoder.mid.block_2.norm1.weight")
converted["encoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("encoder.mid.block_2.norm1.bias")
converted["encoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("encoder.mid.block_2.conv1.weight")
converted["encoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("encoder.mid.block_2.conv1.bias")
converted["encoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("encoder.mid.block_2.norm2.weight")
converted["encoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("encoder.mid.block_2.norm2.bias")
converted["encoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("encoder.mid.block_2.conv2.weight")
converted["encoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("encoder.mid.block_2.conv2.bias")
converted["encoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("encoder.mid.attn_1.norm.weight")
converted["encoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("encoder.mid.attn_1.norm.bias")
converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
converted["encoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
"encoder.mid.attn_1.proj_out.weight"
)
converted["encoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("encoder.mid.attn_1.proj_out.bias")
# 1.4 encoder output
converted["encoder.norm_out.weight"] = original_state_dict.pop("encoder.norm_out.weight")
converted["encoder.norm_out.bias"] = original_state_dict.pop("encoder.norm_out.bias")
converted["encoder.conv_out.weight"] = original_state_dict.pop("encoder.conv_out.weight")
converted["encoder.conv_out.bias"] = original_state_dict.pop("encoder.conv_out.bias")
# 2. Decoder
# 2.1 conv_in
converted["decoder.conv_in.weight"] = original_state_dict.pop("decoder.conv_in.weight")
converted["decoder.conv_in.bias"] = original_state_dict.pop("decoder.conv_in.bias")
# 2.2 mid block
converted["decoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("decoder.mid.block_1.norm1.weight")
converted["decoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("decoder.mid.block_1.norm1.bias")
converted["decoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("decoder.mid.block_1.conv1.weight")
converted["decoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("decoder.mid.block_1.conv1.bias")
converted["decoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("decoder.mid.block_1.norm2.weight")
converted["decoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("decoder.mid.block_1.norm2.bias")
converted["decoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("decoder.mid.block_1.conv2.weight")
converted["decoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("decoder.mid.block_1.conv2.bias")
converted["decoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("decoder.mid.block_2.norm1.weight")
converted["decoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("decoder.mid.block_2.norm1.bias")
converted["decoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("decoder.mid.block_2.conv1.weight")
converted["decoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("decoder.mid.block_2.conv1.bias")
converted["decoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("decoder.mid.block_2.norm2.weight")
converted["decoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("decoder.mid.block_2.norm2.bias")
converted["decoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("decoder.mid.block_2.conv2.weight")
converted["decoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("decoder.mid.block_2.conv2.bias")
converted["decoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("decoder.mid.attn_1.norm.weight")
converted["decoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("decoder.mid.attn_1.norm.bias")
converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
converted["decoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
"decoder.mid.attn_1.proj_out.weight"
)
converted["decoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("decoder.mid.attn_1.proj_out.bias")
# 2.3 up blocks
diffusers_block_idx = 0
for up_block_index in range(len(block_out_channels)):
# resnet blocks
for resnet_block_index in range(layers_per_block + 1):
orig_prefix = f"decoder.up.{up_block_index}.block.{resnet_block_index}"
diff_prefix = f"decoder.up_blocks.{diffusers_block_idx}"
converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
diffusers_block_idx += 1
# upsample blocks
if f"decoder.up.{up_block_index}.upsample.conv.weight" in original_state_dict:
converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.weight"
)
converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.bias"
)
diffusers_block_idx += 1
# 2.4 decoder output
converted["decoder.norm_out.weight"] = original_state_dict.pop("decoder.norm_out.weight")
converted["decoder.norm_out.bias"] = original_state_dict.pop("decoder.norm_out.bias")
converted["decoder.conv_out.weight"] = original_state_dict.pop("decoder.conv_out.weight")
converted["decoder.conv_out.bias"] = original_state_dict.pop("decoder.conv_out.bias")
return converted, original_state_dict


def convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
):
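    """Convert the original refiner VAE state dict to the diffusers AutoencoderKLHunyuanImageRefiner layout."""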
converted = {}
# 1. Encoder
# 1.1 conv_in
converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
# 1.2 Down blocks
for down_block_index in range(len(block_out_channels)): # 0 to 4
# ResNet blocks
for resnet_block_index in range(layers_per_block): # 0 to 1
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
original_state_dict.pop(
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
)
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
original_state_dict.pop(
f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
)
)
converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
)
# Downsample (if exists)
if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
)
converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
)
# 1.3 Mid block
converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_1.conv1.conv.weight"
)
converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_1.conv1.conv.bias"
)
converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_1.conv2.conv.weight"
)
converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_1.conv2.conv.bias"
)
converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_2.conv1.conv.weight"
)
converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_2.conv1.conv.bias"
)
converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
"encoder.mid.block_2.conv2.conv.weight"
)
converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
"encoder.mid.block_2.conv2.conv.bias"
)
# Attention block
converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
"encoder.mid.attn_1.proj_out.weight"
)
converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
"encoder.mid.attn_1.proj_out.bias"
)
# 1.4 Encoder output
converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
# 2. Decoder
# 2.1 conv_in
converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
# 2.2 Mid block
converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_1.conv1.conv.weight"
)
converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_1.conv1.conv.bias"
)
converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_1.conv2.conv.weight"
)
converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_1.conv2.conv.bias"
)
converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_2.conv1.conv.weight"
)
converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_2.conv1.conv.bias"
)
converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
"decoder.mid.block_2.conv2.conv.weight"
)
converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
"decoder.mid.block_2.conv2.conv.bias"
)
# Decoder attention block
converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
"decoder.mid.attn_1.proj_out.weight"
)
converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
"decoder.mid.attn_1.proj_out.bias"
)
# 2.3 Up blocks
    for up_block_index in range(len(block_out_channels)):  # 0 to 4
# ResNet blocks
for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
)
converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
)
# Upsample (if exists)
if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
)
converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
)
# 2.4 Decoder output
converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
return converted, original_state_dict


def main(args):
if args.model_type == "hunyuanimage2.1":
original_transformer_state_dict = load_original_transformer_checkpoint(args)
original_vae_state_dict = load_original_vae_checkpoint(args)
transformer_config = {
"in_channels": 64,
"out_channels": 64,
"num_attention_heads": 28,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"patch_size": (1, 1),
"qk_norm": "rms_norm",
"guidance_embeds": False,
"text_embed_dim": 3584,
"text_embed_2_dim": 1472,
"rope_theta": 256.0,
"rope_axes_dim": (64, 64),
}
converted_transformer_state_dict, original_transformer_state_dict = (
convert_hunyuan_image_transformer_checkpoint_to_diffusers(
original_transformer_state_dict, use_byt5=True, guidance_distilled=False
)
)
if original_transformer_state_dict:
logger.warning(
f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
)
transformer = HunyuanImageTransformer2DModel(**transformer_config)
missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if missing_keys:
logger.warning(f"Missing keys for transformer: {missing_keys}")
if unexpected_key:
logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
vae_config_diffusers = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 64,
"block_out_channels": [128, 256, 512, 512, 1024, 1024],
"layers_per_block": 2,
"spatial_compression_ratio": 32,
"sample_size": 384,
"scaling_factor": 0.75289,
"downsample_match_channel": True,
"upsample_match_channel": True,
}
converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
)
if original_vae_state_dict:
logger.warning(
f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
)
vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
if missing_keys:
logger.warning(f"Missing keys for vae: {missing_keys}")
if unexpected_key:
logger.warning(f"Unexpected keys for vae: {unexpected_key}")
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
elif args.model_type == "hunyuanimage2.1-distilled":
original_transformer_state_dict = load_original_transformer_checkpoint(args)
original_vae_state_dict = load_original_vae_checkpoint(args)
transformer_config = {
"in_channels": 64,
"out_channels": 64,
"num_attention_heads": 28,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"patch_size": (1, 1),
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 3584,
"text_embed_2_dim": 1472,
"rope_theta": 256.0,
"rope_axes_dim": (64, 64),
"use_meanflow": True,
}
converted_transformer_state_dict, original_transformer_state_dict = (
convert_hunyuan_image_transformer_checkpoint_to_diffusers(
original_transformer_state_dict, use_byt5=True, guidance_distilled=True, use_meanflow=True
)
)
if original_transformer_state_dict:
logger.warning(
f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
)
transformer = HunyuanImageTransformer2DModel(**transformer_config)
missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if missing_keys:
logger.warning(f"Missing keys for transformer: {missing_keys}")
if unexpected_key:
logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
vae_config_diffusers = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 64,
"block_out_channels": [128, 256, 512, 512, 1024, 1024],
"layers_per_block": 2,
"spatial_compression_ratio": 32,
"sample_size": 384,
"scaling_factor": 0.75289,
"downsample_match_channel": True,
"upsample_match_channel": True,
}
converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
)
if original_vae_state_dict:
logger.warning(
f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
)
vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
if missing_keys:
logger.warning(f"Missing keys for vae: {missing_keys}")
if unexpected_key:
logger.warning(f"Unexpected keys for vae: {unexpected_key}")
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
elif args.model_type == "hunyuanimage-refiner":
original_transformer_state_dict = load_original_transformer_checkpoint(args)
original_vae_state_dict = load_original_refiner_vae_checkpoint(args)
transformer_config = {
"in_channels": 128,
"out_channels": 64,
"num_layers": 20,
"num_single_layers": 40,
"rope_axes_dim": [16, 56, 56],
"num_attention_heads": 26,
"attention_head_dim": 128,
"mlp_ratio": 4,
"patch_size": (1, 1, 1),
"text_embed_dim": 3584,
"guidance_embeds": True,
}
converted_transformer_state_dict, original_transformer_state_dict = (
convert_hunyuan_image_transformer_checkpoint_to_diffusers(
original_transformer_state_dict, use_byt5=False, guidance_distilled=True
)
)
if original_transformer_state_dict:
logger.warning(
f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
)
transformer = HunyuanImageTransformer2DModel(**transformer_config)
missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if missing_keys:
logger.warning(f"Missing keys for transformer: {missing_keys}")
if unexpected_key:
logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
vae = AutoencoderKLHunyuanImageRefiner()
converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
original_vae_state_dict
)
if original_vae_state_dict:
logger.warning(
f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
)
        missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
        if missing_keys:
            logger.warning(f"Missing keys for vae: {missing_keys}")
        if unexpected_key:
            logger.warning(f"Unexpected keys for vae: {unexpected_key}")
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
if __name__ == "__main__":
main(args)