mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
Fix conversion script
This commit is contained in:
86
debug_conversion.py
Executable file
86
debug_conversion.py
Executable file
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
from diffusers import UNetUnconditionalModel
|
||||
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
|
||||
model_id = "fusing/latent-diffusion-celeba-256"
|
||||
subfolder = "unet"
|
||||
#model_id = "fusing/unet-ldm-dummy"
|
||||
#subfolder = None
|
||||
|
||||
checkpoint = "diffusion_model.pt"
|
||||
config = "config.json"
|
||||
|
||||
if subfolder is not None:
|
||||
checkpoint = os.path.join(subfolder, checkpoint)
|
||||
config = os.path.join(subfolder, config)
|
||||
|
||||
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
|
||||
config_path = hf_hub_download(model_id, config)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
|
||||
|
||||
|
||||
def current_codebase_conversion():
|
||||
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
|
||||
time_step = torch.tensor([10] * noise.shape[0])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
currently_converted_checkpoint = current_codebase_conversion()
|
||||
torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt')
|
||||
|
||||
|
||||
def diff_between_checkpoints(ch_0, ch_1):
|
||||
all_layers_included = False
|
||||
|
||||
if not set(ch_0.keys()) == set(ch_1.keys()):
|
||||
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
|
||||
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
|
||||
print(f"\t{key}")
|
||||
|
||||
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
|
||||
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
|
||||
print(f"\t{key}")
|
||||
else:
|
||||
print("Keys are the same between the two checkpoints")
|
||||
all_layers_included = True
|
||||
|
||||
keys = ch_0.keys()
|
||||
non_equal_keys = []
|
||||
|
||||
if all_layers_included:
|
||||
for key in keys:
|
||||
try:
|
||||
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
|
||||
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
|
||||
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
|
||||
|
||||
if len(non_equal_keys):
|
||||
non_equal_keys = '\n\t'.join(non_equal_keys)
|
||||
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
|
||||
else:
|
||||
print("All keys are equal across checkpoints.")
|
||||
|
||||
|
||||
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
|
||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
@@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None):
|
||||
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||
to them. It splits attention layers, and takes into account additional replacements
|
||||
@@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3))
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
checkpoint[path_map['query']] = query
|
||||
checkpoint[path_map['key']] = key
|
||||
checkpoint[path_map['value']] = value
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map['query']] = query.reshape(target_shape)
|
||||
checkpoint[path_map['key']] = key.reshape(target_shape)
|
||||
checkpoint[path_map['value']] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path['new']
|
||||
@@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
|
||||
for replacement in additional_replacements:
|
||||
new_path = new_path.replace(replacement['old'], replacement['new'])
|
||||
|
||||
checkpoint[new_path] = old_checkpoint[path['old']]
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path['old']]
|
||||
|
||||
|
||||
def convert_ldm_checkpoint(checkpoint, config):
|
||||
@@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
@@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[meta_path],
|
||||
attention_paths_to_split=to_split
|
||||
attention_paths_to_split=to_split,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
attentions = middle_blocks[1]
|
||||
resnet_1 = middle_blocks[2]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint)
|
||||
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
|
||||
|
||||
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint)
|
||||
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
to_split = {
|
||||
@@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
'value': 'mid.attentions.0.value.weight',
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split)
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config['num_res_blocks'] + 1)
|
||||
@@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
||||
|
||||
if ['conv.weight', 'conv.bias'] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
|
||||
@@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
if len(attentions) == 2:
|
||||
attentions = []
|
||||
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
@@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
additional_replacements=[meta_path],
|
||||
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None
|
||||
attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
|
||||
config=config,
|
||||
)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
@@ -296,7 +308,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
|
||||
with open(args.config_file) as f:
|
||||
@@ -304,6 +315,3 @@ if __name__ == "__main__":
|
||||
|
||||
converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
|
||||
torch.save(checkpoint, args.dump_path)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user