Compare commits


1 commit

Author:  sayakpaul
SHA1:    9cc99cff36
Message: add v1.1 SVD conversion
Date:    2024-02-04 23:05:49 +05:30
2 changed files with 85 additions and 27 deletions

View File

@@ -0,0 +1,53 @@
import argparse

import safetensors.torch
import yaml
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import (
    AutoencoderKLTemporalDecoder,
    EulerDiscreteScheduler,
    StableVideoDiffusionPipeline,
    UNetSpatioTemporalConditionModel,
)

# Absolute import: the script is run directly (from the scripts/ directory),
# so a relative import would fail with "no known parent package".
from convert_svd_to_diffusers import (
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    create_unet_diffusers_config,
)

# v1.1 reuses the v1.0 scheduler, image encoder, and feature extractor.
SVD_V1_CKPT = "stabilityai/stable-video-diffusion-img2vid-xt"

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--original_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--config_path", default=None, type=str, required=True, help="Config filepath.")
    parser.add_argument("--dump_path", default=None, type=str)
    parser.add_argument("--push_to_hub", action="store_true")
    args = parser.parse_args()

    original_ckpt = safetensors.torch.load_file(args.original_ckpt_path, device="cpu")
    # yaml.safe_load() parses a stream or string, not a filepath, so open the file first.
    with open(args.config_path) as f:
        config = yaml.safe_load(f)

    unet_config = create_unet_diffusers_config(config, image_size=768)
    unet = UNetSpatioTemporalConditionModel(**unet_config)
    unet_state_dict = convert_ldm_unet_checkpoint(original_ckpt, config)
    unet.load_state_dict(unet_state_dict, strict=True)

    # Build the VAE from the v1.0 config (v1.1 shares the architecture); the
    # class defaults would not match the SVD checkpoint for a strict load.
    vae = AutoencoderKLTemporalDecoder.from_config(
        AutoencoderKLTemporalDecoder.load_config(SVD_V1_CKPT, subfolder="vae")
    )
    vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, config)
    vae.load_state_dict(vae_state_dict, strict=True)

    scheduler = EulerDiscreteScheduler.from_pretrained(SVD_V1_CKPT, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(SVD_V1_CKPT, subfolder="image_encoder")
    feature_extractor = CLIPImageProcessor.from_pretrained(SVD_V1_CKPT, subfolder="feature_extractor")

    pipeline = StableVideoDiffusionPipeline(
        unet=unet, vae=vae, image_encoder=image_encoder, feature_extractor=feature_extractor, scheduler=scheduler
    )
    pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
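Once converted, the result is an ordinary diffusers checkpoint. A minimal smoke test, assuming the pipeline was saved locally; the ./svd-xt-1-1 path and the conditioning image are placeholders, not part of the script above:

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# "./svd-xt-1-1" stands in for whatever --dump_path was passed to the conversion script.
pipe = StableVideoDiffusionPipeline.from_pretrained("./svd-xt-1-1", torch_dtype=torch.float16)
pipe.to("cuda")

# Condition the video model on a single image at SVD's native 1024x576 resolution.
image = load_image("conditioning.png").resize((1024, 576))

# decode_chunk_size trades peak VRAM for decoding speed on the latent frames.
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)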

View File: convert_svd_to_diffusers.py

@@ -12,23 +12,26 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     Creates a config for the diffusers based on the config of the LDM model.
     """
     if controlnet:
-        unet_params = original_config.model.params.control_stage_config.params
+        unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
     else:
-        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
-            unet_params = original_config.model.params.unet_config.params
+        if (
+            "unet_config" in original_config["model"]["params"]
+            and original_config["model"]["params"]["unet_config"] is not None
+        ):
+            unet_params = original_config["model"]["params"]["unet_config"]["params"]
         else:
-            unet_params = original_config.model.params.network_config.params
+            unet_params = original_config["model"]["params"]["network_config"]["params"]
 
-    vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["encoder_config"]["params"]
 
-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
 
     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
         block_type = (
             "CrossAttnDownBlockSpatioTemporal"
-            if resolution in unet_params.attention_resolutions
+            if resolution in unet_params["attention_resolutions"]
             else "DownBlockSpatioTemporal"
         )
         down_block_types.append(block_type)
@@ -39,32 +42,32 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     for i in range(len(block_out_channels)):
         block_type = (
             "CrossAttnUpBlockSpatioTemporal"
-            if resolution in unet_params.attention_resolutions
+            if resolution in unet_params["attention_resolutions"]
             else "UpBlockSpatioTemporal"
         )
         up_block_types.append(block_type)
         resolution //= 2
 
-    if unet_params.transformer_depth is not None:
+    if unet_params["transformer_depth"] is not None:
         transformer_layers_per_block = (
-            unet_params.transformer_depth
-            if isinstance(unet_params.transformer_depth, int)
-            else list(unet_params.transformer_depth)
+            unet_params["transformer_depth"]
+            if isinstance(unet_params["transformer_depth"], int)
+            else list(unet_params["transformer_depth"])
         )
     else:
         transformer_layers_per_block = 1
 
-    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
 
-    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
     use_linear_projection = (
-        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
-            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
+            head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
+            head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
 
     class_embed_type = None
     addition_embed_type = None
@@ -72,23 +75,25 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     projection_class_embeddings_input_dim = None
     context_dim = None
 
-    if unet_params.context_dim is not None:
+    if unet_params["context_dim"] is not None:
         context_dim = (
-            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+            unet_params["context_dim"]
+            if isinstance(unet_params["context_dim"], int)
+            else unet_params["context_dim"][0]
         )
 
     if "num_classes" in unet_params:
-        if unet_params.num_classes == "sequential":
+        if unet_params["num_classes"] == "sequential":
             addition_time_embed_dim = 256
             assert "adm_in_channels" in unet_params
-            projection_class_embeddings_input_dim = unet_params.adm_in_channels
+            projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
 
     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels": unet_params.in_channels,
+        "in_channels": unet_params["in_channels"],
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
-        "layers_per_block": unet_params.num_res_blocks,
+        "layers_per_block": unet_params["num_res_blocks"],
         "cross_attention_dim": context_dim,
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
@@ -100,15 +105,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     }
 
     if "disable_self_attentions" in unet_params:
-        config["only_cross_attention"] = unet_params.disable_self_attentions
+        config["only_cross_attention"] = unet_params["disable_self_attentions"]
 
-    if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
-        config["num_class_embeds"] = unet_params.num_classes
+    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
+        config["num_class_embeds"] = unet_params["num_classes"]
 
     if controlnet:
-        config["conditioning_channels"] = unet_params.hint_channels
+        config["conditioning_channels"] = unet_params["hint_channels"]
     else:
-        config["out_channels"] = unet_params.out_channels
+        config["out_channels"] = unet_params["out_channels"]
         config["up_block_types"] = tuple(up_block_types)
 
     return config
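The mechanical change above follows from how the config now arrives: the new v1.1 entry point loads it with yaml.safe_load, which returns plain Python dicts, while the old attribute-style access (original_config.model.params...) only works on wrapper objects such as OmegaConf's DictConfig. A minimal sketch of the difference; the toy config is illustrative, not the SVD config:

import yaml

# yaml.safe_load yields plain dicts and lists, so only subscript access works.
cfg = yaml.safe_load("model:\n  params:\n    unet_config:\n      params: {model_channels: 320}")

print(cfg["model"]["params"]["unet_config"]["params"]["model_channels"])  # -> 320

try:
    cfg.model  # the old attribute-style access pattern
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'model'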