@@ -108,6 +108,7 @@ CHECKPOINT_KEY_NAMES = {
    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -162,6 +163,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
}

# Use to configure model sample size when original config is provided
@@ -624,6 +626,9 @@ def infer_diffusers_model_type(checkpoint):
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
        model_type = "mochi-1-preview"

    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
        model_type = "hunyuan-video"

    else:
        model_type = "v1"
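
    # For example, a checkpoint containing the key
    # "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias" is now
    # classified as "hunyuan-video".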
@@ -2522,3 +2527,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")

    return new_state_dict


def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
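    # The original final layer emits modulation as (shift, scale), while diffusers'
    # norm_out (AdaLayerNormContinuous) chunks its linear output as (scale, shift),
    # hence the reorder below.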
    def remap_norm_scale_shift_(key, state_dict):
        weight = state_dict.pop(key)
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
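
    # The token-refiner subtree is renamed wholesale; its fused self_attn_qkv
    # weights are additionally split into separate q/k/v projections.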
    def remap_txt_in_(key, state_dict):
        def rename_key(key):
            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
            new_key = new_key.replace("txt_in", "context_embedder")
            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
            new_key = new_key.replace("mlp", "ff")
            return new_key

        if "self_attn_qkv" in key:
            weight = state_dict.pop(key)
            to_q, to_k, to_v = weight.chunk(3, dim=0)
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
        else:
            state_dict[rename_key(key)] = state_dict.pop(key)
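
    # Fused QKV projections are split three ways: image-stream weights map to
    # attn.to_{q,k,v}, text-stream weights to the joint attention's
    # attn.add_{q,k,v}_proj.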
    def remap_img_attn_qkv_(key, state_dict):
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v

    def remap_txt_attn_qkv_(key, state_dict):
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
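
    # In the single-stream blocks, linear1 fuses [Q; K; V; MLP up-projection] into
    # one matrix. It is split along dim 0: three hidden_size-sized attention chunks,
    # with the MLP chunk taking whatever remains.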
    def remap_single_transformer_blocks_(key, state_dict):
        hidden_size = 3072

        if "linear1.weight" in key:
            linear1_weight = state_dict.pop(key)
            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
            state_dict[f"{new_key}.attn.to_q.weight"] = q
            state_dict[f"{new_key}.attn.to_k.weight"] = k
            state_dict[f"{new_key}.attn.to_v.weight"] = v
            state_dict[f"{new_key}.proj_mlp.weight"] = mlp

        elif "linear1.bias" in key:
            linear1_bias = state_dict.pop(key)
            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

        else:
            new_key = key.replace("single_blocks", "single_transformer_blocks")
            new_key = new_key.replace("linear2", "proj_out")
            new_key = new_key.replace("q_norm", "attn.norm_q")
            new_key = new_key.replace("k_norm", "attn.norm_k")
            state_dict[new_key] = state_dict.pop(key)
    TRANSFORMER_KEYS_RENAME_DICT = {
        "img_in": "x_embedder",
        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
        "double_blocks": "transformer_blocks",
        "img_attn_q_norm": "attn.norm_q",
        "img_attn_k_norm": "attn.norm_k",
        "img_attn_proj": "attn.to_out.0",
        "txt_attn_q_norm": "attn.norm_added_q",
        "txt_attn_k_norm": "attn.norm_added_k",
        "txt_attn_proj": "attn.to_add_out",
        "img_mod.linear": "norm1.linear",
        "img_norm1": "norm1.norm",
        "img_norm2": "norm2",
        "img_mlp": "ff",
        "txt_mod.linear": "norm1_context.linear",
        "txt_norm1": "norm1.norm",
        "txt_norm2": "norm2_context",
        "txt_mlp": "ff_context",
        "self_attn_proj": "attn.to_out.0",
        "modulation.linear": "norm.linear",
        "pre_norm": "norm.norm",
        "final_layer.norm_final": "norm_out.norm",
        "final_layer.linear": "proj_out",
        "fc1": "net.0.proj",
        "fc2": "net.2",
        "input_embedder": "proj_in",
    }
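
    # These renames are applied as plain substring replacements in the loop below,
    # so each source pattern must be unambiguous across checkpoint key families.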
    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        "txt_in": remap_txt_in_,
        "img_attn_qkv": remap_img_attn_qkv_,
        "txt_attn_qkv": remap_txt_attn_qkv_,
        "single_blocks": remap_single_transformer_blocks_,
        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
    }
    def update_state_dict_(state_dict, old_key, new_key):
        state_dict[new_key] = state_dict.pop(old_key)
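
    # Two passes over the checkpoint: first the one-to-one key renames, then the
    # special handlers that split or reorder fused tensors in place.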
    for key in list(checkpoint.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(checkpoint, key, new_key)

    for key in list(checkpoint.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, checkpoint)

    return checkpoint
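
# A minimal usage sketch (the filename is illustrative, not from this PR):
#
#   import torch
#   from diffusers import HunyuanVideoTransformer3DModel
#
#   transformer = HunyuanVideoTransformer3DModel.from_single_file(
#       "hunyuan_video_720_cfgdistill_bf16.safetensors", torch_dtype=torch.bfloat16
#   )
#
# from_single_file is expected to detect the model type via
# infer_diffusers_model_type and, on a "hunyuan-video" match, convert the weights
# with convert_hunyuan_video_transformer_to_diffusers (that wiring lives outside
# this hunk).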