Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-15 00:44:51 +08:00)

Compare commits: update-sty...improve_co (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 7965655fd3 | |
| | f22326de59 | |
| | 2595aa0c2f | |
| | 4725e488b9 | |
| | 4ab89f22fd | |
@@ -102,15 +102,6 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     mapping = []
     for old_item in old_list:
         new_item = old_item
-
-        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
-        # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
-        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
-        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
         mapping.append({"old": old_item, "new": new_item})
 
     return mapping
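For reference, with the commented-out renames removed, `renew_attention_paths` now builds a plain identity mapping. A minimal sketch of the returned structure, using made-up key names:

```python
# Made-up attention key names; not taken from a real checkpoint.
old_list = [
    "middle_block.1.norm.weight",
    "middle_block.1.proj_out.weight",
]

# Equivalent to what the trimmed function body produces.
mapping = [{"old": item, "new": item} for item in old_list]
print(mapping)
# [{'old': 'middle_block.1.norm.weight', 'new': 'middle_block.1.norm.weight'},
#  {'old': 'middle_block.1.proj_out.weight', 'new': 'middle_block.1.proj_out.weight'}]
```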
@@ -476,15 +467,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint
 
 
-def convert_ldm_vae_checkpoint(checkpoint, config):
+def convert_ldm_vae_checkpoint(vae_state_dict, config):
     # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
     new_checkpoint = {}
 
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
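The signature change moves the `first_stage_model.` prefix stripping out of `convert_ldm_vae_checkpoint` and onto the caller, which now passes a VAE-only state dict. A minimal sketch of the new calling convention, with toy keys standing in for a real checkpoint (`vae_config` is assumed to exist):

```python
# Toy LDM-style state dict: only keys under "first_stage_model." belong to the VAE.
checkpoint = {
    "first_stage_model.encoder.conv_in.weight": 0.0,
    "model.diffusion_model.input_blocks.0.0.weight": 0.0,  # UNet key, ignored here
}

vae_key = "first_stage_model."
vae_state_dict = {k[len(vae_key):]: v for k, v in checkpoint.items() if k.startswith(vae_key)}
print(vae_state_dict)  # {'encoder.conv_in.weight': 0.0}

# The function is then called with the VAE-only dict:
# converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
```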
@@ -793,6 +777,12 @@ if __name__ == "__main__":
         type=str,
         help="The YAML config file corresponding to the original architecture.",
     )
+    parser.add_argument(
+        "--vae_checkpoint_path",
+        default=None,
+        type=str,
+        help="The path to a vae checkpoint. If left to `None` the vae will be extracted from `checkpoint_path`."
+    )
     parser.add_argument(
         "--num_in_channels",
         default=None,
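A minimal sketch of how the new flag surfaces through argparse, using hypothetical checkpoint file names (the real script defines many more arguments):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument(
    "--vae_checkpoint_path",
    default=None,
    type=str,
    help="The path to a vae checkpoint. If left to `None` the vae will be extracted from `checkpoint_path`."
)

# Hypothetical invocation with a separately trained VAE checkpoint:
args = parser.parse_args(["--checkpoint_path", "sd-v1-4.ckpt", "--vae_checkpoint_path", "vae-ft-mse.ckpt"])
print(args.vae_checkpoint_path)  # 'vae-ft-mse.ckpt'; stays None when the flag is omitted
```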
@@ -861,7 +851,9 @@ if __name__ == "__main__":
     else:
         print("global_step key not found in model")
         global_step = None
-    checkpoint = checkpoint["state_dict"]
+
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
 
     upcast_attention = False
     if args.original_config_file is None:
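The guard matters because some `.ckpt` files wrap their weights under a top-level `state_dict` key while others are already bare state dicts; the old unconditional indexing would fail on the latter. A minimal sketch of the guard's effect, using toy dicts rather than real checkpoints:

```python
wrapped = {"global_step": 1000, "state_dict": {"model.weight": 0.0}}
bare = {"model.weight": 0.0}

for checkpoint in (wrapped, bare):
    # Same guard as in the script: unwrap only when a "state_dict" key exists.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    print(sorted(checkpoint))  # both end up as a plain key -> weight mapping
```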
@@ -960,7 +952,19 @@ if __name__ == "__main__":
 
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
-    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+    if args.vae_checkpoint_path is not None:
+        vae_state_dict = torch.load(args.vae_checkpoint_path)
+        vae_state_dict = vae_state_dict["state_dict"]
+    else:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
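Taken together, the VAE weights now come either from the file passed via `--vae_checkpoint_path` or, as before, from the `first_stage_model.` keys of the main checkpoint. A minimal standalone sketch of that selection logic, exercised with a toy in-memory checkpoint; the helper name `pick_vae_state_dict` is made up for illustration:

```python
import torch


def pick_vae_state_dict(checkpoint, vae_checkpoint_path=None):
    # Mirrors the new branch in the script: prefer a separate VAE file when one is given.
    if vae_checkpoint_path is not None:
        vae_state_dict = torch.load(vae_checkpoint_path, map_location="cpu")
        return vae_state_dict["state_dict"]
    # Otherwise fall back to extracting the "first_stage_model." keys as before.
    vae_key = "first_stage_model."
    return {k.replace(vae_key, ""): v for k, v in checkpoint.items() if k.startswith(vae_key)}


# With no separate VAE file, the weights come out of the main checkpoint:
toy_checkpoint = {"first_stage_model.encoder.conv_in.weight": torch.zeros(1)}
print(pick_vae_state_dict(toy_checkpoint))  # {'encoder.conv_in.weight': tensor([0.])}
```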