Fix conversion script

This commit is contained in:
Patrick von Platen
2022-07-15 17:00:41 +00:00
parent 87060e6a9c
commit 3f1e95928e
3 changed files with 113 additions and 19 deletions

86
debug_conversion.py Executable file
View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python3
import json
import os
from diffusers import UNetUnconditionalModel
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
from huggingface_hub import hf_hub_download
import torch
model_id = "fusing/latent-diffusion-celeba-256"
subfolder = "unet"
#model_id = "fusing/unet-ldm-dummy"
#subfolder = None
checkpoint = "diffusion_model.pt"
config = "config.json"
if subfolder is not None:
checkpoint = os.path.join(subfolder, checkpoint)
config = os.path.join(subfolder, config)
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
config_path = hf_hub_download(model_id, config)
with open(config_path) as f:
config = json.load(f)
checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
def current_codebase_conversion():
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
return model.state_dict()
currently_converted_checkpoint = current_codebase_conversion()
torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt')
def diff_between_checkpoints(ch_0, ch_1):
all_layers_included = False
if not set(ch_0.keys()) == set(ch_1.keys()):
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
print(f"\t{key}")
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
print(f"\t{key}")
else:
print("Keys are the same between the two checkpoints")
all_layers_included = True
keys = ch_0.keys()
non_equal_keys = []
if all_layers_included:
for key in keys:
try:
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
except RuntimeError as e:
print(e)
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
if len(non_equal_keys):
non_equal_keys = '\n\t'.join(non_equal_keys)
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
else:
print("All keys are equal across checkpoints.")
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)

0
scripts/__init__.py Normal file
View File

View File

@@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
return mapping
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None):
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
@@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3))
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
checkpoint[path_map['query']] = query
checkpoint[path_map['key']] = key
checkpoint[path_map['value']] = value
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map['query']] = query.reshape(target_shape)
checkpoint[path_map['key']] = key.reshape(target_shape)
checkpoint[path_map['value']] = value.reshape(target_shape)
for path in paths:
new_path = path['new']
@@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
checkpoint[new_path] = old_checkpoint[path['old']]
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path['old']]
def convert_ldm_checkpoint(checkpoint, config):
@@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint,
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split
attention_paths_to_split=to_split,
config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
attentions_paths = renew_attention_paths(attentions)
to_split = {
@@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
'value': 'mid.attentions.0.value.weight',
},
}
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split)
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
for i in range(num_output_blocks):
block_id = i // (config['num_res_blocks'] + 1)
@@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(resnets)
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
if ['conv.weight', 'conv.bias'] in output_block_list.values():
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
@@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
@@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint,
checkpoint,
additional_replacements=[meta_path],
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
config=config,
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
@@ -296,7 +308,6 @@ if __name__ == "__main__":
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)
with open(args.config_file) as f:
@@ -304,6 +315,3 @@ if __name__ == "__main__":
converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
torch.save(checkpoint, args.dump_path)