Compare commits

5 Commits

Author              SHA1        Message                                                                     Date
Patrick von Platen  7965655fd3  up                                                                          2022-12-15 11:14:13 +00:00
Patrick von Platen  f22326de59  Merge branch 'main' of https://github.com/huggingface/diffusers into main  2022-12-15 11:13:41 +00:00
Patrick von Platen  2595aa0c2f  Merge branch 'main' of https://github.com/huggingface/diffusers into main  2022-12-14 14:18:16 +00:00
Patrick von Platen  4725e488b9  Merge branch 'main' of https://github.com/huggingface/diffusers into main  2022-12-14 11:19:35 +00:00
Patrick von Platen  4ab89f22fd  Remove bogus file                                                           2022-12-14 11:19:31 +00:00

@@ -102,15 +102,6 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     mapping = []
     for old_item in old_list:
         new_item = old_item
-        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
-        # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
         mapping.append({"old": old_item, "new": new_item})

     return mapping
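
With the commented-out replacements deleted, the function reduces to an identity mapping over the given keys. A quick self-contained illustration of the post-cleanup behavior (the sample key is made up, and the body is paraphrased from the hunk above):

    def renew_attention_paths(old_list, n_shave_prefix_segments=0):
        # After the cleanup, every key simply maps to itself.
        mapping = []
        for old_item in old_list:
            mapping.append({"old": old_item, "new": old_item})
        return mapping

    print(renew_attention_paths(["mid_block.attentions.0.norm.weight"]))
    # [{'old': 'mid_block.attentions.0.norm.weight', 'new': 'mid_block.attentions.0.norm.weight'}]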
@@ -476,15 +467,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     return new_checkpoint

-def convert_ldm_vae_checkpoint(checkpoint, config):
-    # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+def convert_ldm_vae_checkpoint(vae_state_dict, config):
     new_checkpoint = {}
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -793,6 +777,12 @@ if __name__ == "__main__":
         type=str,
         help="The YAML config file corresponding to the original architecture.",
     )
+    parser.add_argument(
+        "--vae_checkpoint_path",
+        default=None,
+        type=str,
+        help="The path to a vae checkpoint. If left to `None` the vae will be extracted from `checkpoint_path`."
+    )
     parser.add_argument(
         "--num_in_channels",
         default=None,
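
For context, a hypothetical invocation that exercises the new flag, driven from Python (the script path and the `--checkpoint_path`/`--dump_path` flags are assumed from the diffusers repo and from the help text above; the file paths are placeholders):

    import subprocess

    subprocess.run(
        [
            "python",
            "scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path", "sd-v1-5.ckpt",             # full LDM checkpoint (placeholder)
            "--vae_checkpoint_path", "fine_tuned_vae.ckpt",  # separately trained VAE (placeholder)
            "--dump_path", "./converted_model",
        ],
        check=True,
    )

Leaving `--vae_checkpoint_path` unset keeps the old behavior: the VAE weights are carved out of the main checkpoint itself.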
@@ -861,7 +851,9 @@ if __name__ == "__main__":
     else:
         print("global_step key not found in model")
         global_step = None
-    checkpoint = checkpoint["state_dict"]
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]

     upcast_attention = False
     if args.original_config_file is None:
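
The guard added here lets the script accept both wrapper formats. A minimal sketch of the pattern, assuming a checkpoint loaded with `torch.load` (the path is a placeholder):

    import torch

    checkpoint = torch.load("some_model.ckpt", map_location="cpu")  # placeholder path
    # Lightning-style checkpoints wrap the weights under a "state_dict" key,
    # while bare state dicts are already flat; unwrap only when needed.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]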
@@ -960,7 +952,19 @@ if __name__ == "__main__":
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
-    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+    if args.vae_checkpoint_path is not None:
+        vae_state_dict = torch.load(args.vae_checkpoint_path)
+        vae_state_dict = vae_state_dict["state_dict"]
+    else:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)

     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
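
Downstream, the point of converting a separately fine-tuned VAE is to pair it with a pipeline. A sketch under the assumption that the converted model was dumped in diffusers format (the paths and the pipeline pairing are illustrative, not part of this diff):

    from diffusers import AutoencoderKL, StableDiffusionPipeline

    vae = AutoencoderKL.from_pretrained("./converted_model/vae")  # placeholder path
    pipe = StableDiffusionPipeline.from_pretrained("./converted_model", vae=vae)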