@@ -106,6 +106,7 @@ CHECKPOINT_KEY_NAMES = {
     ],
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -157,6 +158,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -610,6 +612,9 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "autoencoder-dc-f128c512"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
     else:
         model_type = "v1"
 
@@ -1750,6 +1755,12 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def swap_proj_gate(weight):
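+    # the checkpoint stores the fused feed-forward input projection as [proj, gate];
+    # diffusers expects [gate, proj], so swap the two halves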
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
 def get_attn2_layers(state_dict):
     attn2_layers = []
     for key in state_dict.keys():
@@ -2406,3 +2417,101 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
         handler_fn_inplace(key, converted_state_dict)
 
     return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
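+    # t5_y_embedder (to_q / to_kv / to_out below) pools the T5 caption embeddings
+    # with attention; it maps to time_embed.pooler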
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
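+        # qkv_x is a single fused projection; split it into equal Q, K, V chunks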
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_y.weight"
+        )
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_y.weight"
+        )
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.weight"
+            )
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
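+        # w1 fuses the proj and gate halves of the feed-forward input projection;
+        # swap_proj_gate reorders them for diffusers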
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
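+    # final_layer.mod packs [shift, scale]; swap_scale_shift reorders it to the
+    # [scale, shift] layout diffusers expects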
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
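+    # the positional (RoPE) frequencies are stored in the checkpoint; carry them
+    # over unchanged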
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
|
|
|
|
|
|
|
|
|
return new_state_dict