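"""
Conversion script for the DeepFloyd IF checkpoints.

It builds `diffusers` pipelines out of the original IF weights: the stage 1 UNet becomes an
`IFPipeline`, the stage 2/3 super-resolution UNets become `IFSuperResolutionPipeline`s, and the
p_head / w_head arrays are combined with a CLIP vision model into an `IFSafetyChecker`. The T5
tokenizer/encoder and the CLIP image processor are loaded from the Hugging Face Hub rather than
converted.
"""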
import argparse
import inspect
import os

import numpy as np
import torch
import yaml
from torch.nn import functional as F
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer

from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", required=False, default=None, type=str)

    parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)

    parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)

    parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")

    parser.add_argument(
        "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
    )

    parser.add_argument(
        "--unet_checkpoint_path_stage_2",
        required=False,
        default=None,
        type=str,
        help="Path to stage 2 unet checkpoint file",
    )

    parser.add_argument(
        "--unet_checkpoint_path_stage_3",
        required=False,
        default=None,
        type=str,
        help="Path to stage 3 unet checkpoint file",
    )

    parser.add_argument("--p_head_path", type=str, required=True)

    parser.add_argument("--w_head_path", type=str, required=True)

    args = parser.parse_args()

    return args

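# Example invocation for the stage 1 conversion (the script path and all file paths below are
# placeholders for your local files):
#
#   python convert_if.py \
#       --p_head_path p_head.npz \
#       --w_head_path w_head.npz \
#       --unet_config stage_1_config.yml \
#       --unet_checkpoint_path stage_1_pytorch_model.bin \
#       --dump_path ./IF-I-diffusers
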
def main(args):
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)

    if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None:
        convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)

    if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)

    if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)

def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
    unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)

    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.5,
    )

    pipe = IFPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(args.dump_path)

def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
    if stage == 2:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_2
        sample_size = None
        dump_path = args.dump_path_stage_2
    elif stage == 3:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_3
        sample_size = 1024
        dump_path = args.dump_path_stage_3
    else:
        assert False

    unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)

    image_noising_scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
    )

    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.0,
    )

    pipe = IFSuperResolutionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(dump_path)

def get_stage_1_unet(unet_config, unet_checkpoint_path):
    # `unet_config` is a path to the original yaml config file
    with open(unet_config, "r") as f:
        original_unet_config = yaml.safe_load(f)

    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = create_unet_diffusers_config(original_unet_config)

    unet = UNet2DConditionModel(**unet_diffusers_config)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )

    unet.load_state_dict(converted_unet_checkpoint)

    return unet

def convert_safety_checker(p_head_path, w_head_path):
    state_dict = {}

    # p head

    p_head = np.load(p_head_path)

    p_head_weights = p_head["weights"]
    p_head_weights = torch.from_numpy(p_head_weights)
    p_head_weights = p_head_weights.unsqueeze(0)

    p_head_biases = p_head["biases"]
    p_head_biases = torch.from_numpy(p_head_biases)
    p_head_biases = p_head_biases.unsqueeze(0)

    state_dict["p_head.weight"] = p_head_weights
    state_dict["p_head.bias"] = p_head_biases

    # w head

    w_head = np.load(w_head_path)

    w_head_weights = w_head["weights"]
    w_head_weights = torch.from_numpy(w_head_weights)
    w_head_weights = w_head_weights.unsqueeze(0)

    w_head_biases = w_head["biases"]
    w_head_biases = torch.from_numpy(w_head_biases)
    w_head_biases = w_head_biases.unsqueeze(0)

    state_dict["w_head.weight"] = w_head_weights
    state_dict["w_head.bias"] = w_head_biases

    # vision model

    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    vision_model_state_dict = vision_model.state_dict()

    for key, value in vision_model_state_dict.items():
        key = f"vision_model.{key}"
        state_dict[key] = value

    # full model

    config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = IFSafetyChecker(config)

    safety_checker.load_state_dict(state_dict)

    return safety_checker

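# Note on the head files (illustrative, inferred from how they are read above): `--p_head_path`
# and `--w_head_path` should point to numpy archives exposing "weights" and "biases" arrays. Each
# pair describes a single linear head on top of the CLIP image embedding, so the arrays are
# unsqueezed to gain the leading output dimension an `nn.Linear` state dict expects, i.e. a
# weights array of shape (d,) is stored as a (1, d) `p_head.weight` / `w_head.weight` tensor.
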
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]

    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    projection_class_embeddings_input_dim = None

    if class_embed_type is None:
        if "num_classes" in original_unet_config:
            if original_unet_config["num_classes"] == "sequential":
                class_embed_type = "projection"
                assert "adm_in_channels" in original_unet_config
                projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
            else:
                raise NotImplementedError(
                    f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
                )

    config = {
        "sample_size": original_unet_config["image_size"],
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": original_unet_config["num_res_blocks"],
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "addition_embed_type": "text",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    if "encoder_dim" in original_unet_config:
        config["encoder_hid_dim"] = original_unet_config["encoder_dim"]

    return config

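# Illustration of the resolution bookkeeping above (numbers are only an example, not a real IF
# config): with image_size=64 and attention_resolutions="32,16,8" the list becomes
# [64 // 32, 64 // 16, 64 // 8] = [2, 4, 8]; the running `resolution` starts at 1 and doubles
# after every down block except the last, so only the blocks whose running resolution falls in
# that list get a cross-attention block type.
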
def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] in [None, "identity"]:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # TODO need better check than i in [4, 8, 12, 16]
        block_type = config["down_block_types"][block_id]
        if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
            4,
            8,
            12,
            16,
        ]:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(output_block_list) == 3:
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")

    if "encoder_pooling.0.weight" in unet_state_dict:
        new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
        new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")

        new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
            "encoder_pooling.1.positional_embedding"
        )
        new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
        new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
        new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
        new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
        new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
        new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")

        new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
        new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")

        new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
        new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")

    return new_checkpoint

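# Worked example for the index arithmetic above (illustrative): with layers_per_block=3, original
# keys `input_blocks.1 ... input_blocks.4` land in down block 0 because (i - 1) // (3 + 1) == 0,
# and `input_blocks.5 ... input_blocks.8` land in down block 1; (i - 1) % (3 + 1) gives the slot
# inside the block, and every 4th index (4, 8, 12, 16) is treated as the block's downsampler.
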
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

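# For example, shave_segments("output_blocks.1.0.in_layers.0.weight", 2) returns
# "0.in_layers.0.weight", while a negative count keeps the prefix instead:
# shave_segments("a.b.c", -1) returns "a.b".
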
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        if "qkv" in new_item:
            continue

        if "encoder_kv" in new_item:
            continue

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight")
        new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
    qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
    qkv_weight = qkv_weight[:, :, 0]

    qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")

    is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]

    split = 1 if is_cross_attn_only else 3

    weights, bias = split_attentions(
        weight=qkv_weight,
        bias=qkv_bias,
        split=split,
        chunk_size=config["attention_head_dim"],
    )

    if is_cross_attn_only:
        query_weight, q_bias = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
    else:
        [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
        new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
        new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
        new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
        new_checkpoint[f"{new_path}.to_v.bias"] = v_bias

    encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
    encoder_kv_weight = encoder_kv_weight[:, :, 0]

    encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")

    [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
        weight=encoder_kv_weight,
        bias=encoder_kv_bias,
        split=2,
        chunk_size=config["attention_head_dim"],
    )

    new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
    new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
    new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
    new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias

def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    for path in paths:
        new_path = path["new"]

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]

# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
    weights = [None] * split
    biases = [None] * split

    weights_biases_idx = 0

    for starting_row_index in range(0, weight.shape[0], chunk_size):
        row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)

        weight_rows = weight[row_indices, :]
        bias_rows = bias[row_indices]

        if weights[weights_biases_idx] is None:
            weights[weights_biases_idx] = weight_rows
            biases[weights_biases_idx] = bias_rows
        else:
            assert weights[weights_biases_idx] is not None
            weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
            biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])

        weights_biases_idx = (weights_biases_idx + 1) % split

    return weights, biases

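# Illustration of the interleaved split above (not from the original script): the round-robin loop
# assumes the fused projection lays out its rows per head as consecutive chunk_size-sized blocks,
# e.g. [q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...] for split=3. Each chunk is dealt
# into the matching bin and the bins are concatenated back into separate to_q / to_k / to_v
# (or add_k_proj / add_v_proj for split=2) weight matrices.
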
def parse_list(value):
    if isinstance(value, str):
        value = value.split(",")
        value = [int(v) for v in value]
    elif isinstance(value, list):
        pass
    else:
        raise ValueError(f"Can't parse list for type: {type(value)}")

    return value

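# For example, parse_list("32,16,8") returns [32, 16, 8]; a value that is already a list is passed
# through unchanged.
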
# below is copy and pasted from original convert_if_stage_2.py script
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    orig_path = unet_checkpoint_path

    # `orig_path` is a directory containing the original `config.yml` and `pytorch_model.bin`
    with open(os.path.join(orig_path, "config.yml"), "r") as f:
        original_unet_config = yaml.safe_load(f)

    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
        original_unet_config["channel_mult"].split(",")[-1]
    )
    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
    unet_diffusers_config["class_embed_type"] = "timestep"
    unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["only_cross_attention"] = (
        bool(original_unet_config["disable_self_attentions"])
        if (
            "disable_self_attentions" in original_unet_config
            and isinstance(original_unet_config["disable_self_attentions"], int)
        )
        else True
    )

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # check that architecture matches - is a bit slow
        # (the check function is named `check_param_count` so it is not shadowed by this boolean argument)
        check_param_count(orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)
    expected_weights = model.state_dict().keys()

    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model

def superres_create_unet_diffusers_config(original_unet_config):
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]
    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in original_unet_config:
        if original_unet_config["num_classes"] == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in original_unet_config
            projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
        else:
            raise NotImplementedError(
                f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
            )

    config = {
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    return config

def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]

    if "encoder_pooling.0.weight" in unet_state_dict:
        mapping = {
            "encoder_pooling.0": "add_embedding.norm1",
            "encoder_pooling.1": "add_embedding.pool",
            "encoder_pooling.2": "add_embedding.proj",
            "encoder_pooling.3": "add_embedding.norm2",
        }
        for key in unet_state_dict.keys():
            if key.startswith("encoder_pooling"):
                prefix = key[: len("encoder_pooling.0")]
                new_key = key.replace(prefix, mapping[prefix])
                new_checkpoint[new_key] = unet_state_dict[key]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
        downsampler_ids = layers_per_block_cumsum
    else:
        # TODO need better check than i in [4, 8, 12, 16]
        downsampler_ids = [4, 8, 12, 16]

    for i in range(1, num_input_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = (i - 1) - passed_blocks

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        block_type = config["down_block_types"][block_id]
        if (
            block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
        ) and i in downsampler_ids:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))

    for i in range(num_output_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = i - passed_blocks

        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]

            has_attention = True
            if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
                has_attention = False

            maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # this layer has no attention
                has_attention = False
                maybe_attentions = []

            if has_attention:
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(maybe_attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
                layer_id = len(output_block_list) - 1
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint

# named `check_param_count` so it is not shadowed by the `verify_param_count` boolean argument of
# `get_super_res_unet`
def check_param_count(orig_path, unet_diffusers_config):
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II = IFStageII(device="cpu", dir_or_name=orig_path)
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
    else:
        assert False, f"Weird name. Should have -II- or -III- in path: {orig_path}"

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, if_II.model.time_embed)
    assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])

    # downblocks
    assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
    assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
    assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])

    if "-II-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])
    if "-III-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
        assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])

    # mid block
    assert_param_count(unet.mid_block, if_II.model.middle_block)

    # up block
    if "-II-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])
    if "-III-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
        assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])

    # out params
    assert_param_count(unet.conv_norm_out, if_II.model.out[0])
    assert_param_count(unet.conv_out, if_II.model.out[2])

    # make sure all model architecture has same param count
    assert_param_count(unet, if_II.model)

def assert_param_count(model_1, model_2):
    count_1 = sum(p.numel() for p in model_1.parameters())
    count_2 = sum(p.numel() for p in model_2.parameters())
    assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"

def superres_check_against_original(dump_path, unet_checkpoint_path):
    model_path = dump_path
    model = UNet2DConditionModel.from_pretrained(model_path)
    model.to("cuda")
    orig_path = unet_checkpoint_path

    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    else:
        assert False, f"Weird name. Should have -II- or -III- in path: {orig_path}"

    batch_size = 1
    channels = model.config.in_channels // 2
    height = model.config.sample_size
    width = model.config.sample_size

    height = 1024
    width = 1024

    torch.manual_seed(0)

    latents = torch.randn((batch_size, channels, height, width), device=model.device)
    image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)

    interpolate_antialias = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        interpolate_antialias["antialias"] = True

    image_upscaled = F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
    )

    latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
    t = torch.tensor([5], device=model.device).to(model.dtype)

    seq_len = 64
    encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
        model.dtype
    )

    fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)

    with torch.no_grad():
        out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)

    if_II_model.to("cpu")
    del if_II_model
    import gc

    torch.cuda.empty_cache()
    gc.collect()
    print(50 * "=")

    with torch.no_grad():
        noise_pred = model(
            sample=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=fake_class_labels,
            timestep=t,
        ).sample

    print("Out shape", noise_pred.shape)
    print("Diff", (out - noise_pred).abs().sum())

if __name__ == "__main__":
    main(parse_args())