mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 05:24:20 +08:00
Compare commits
1 Commits
fix-part-t
...
convert-sv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cc99cff36 |
53
scripts/convert_svd_1_1_to_diffusers.py
Normal file
53
scripts/convert_svd_1_1_to_diffusers.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import argparse
|
||||
|
||||
import safetensors.torch
|
||||
import yaml
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLTemporalDecoder,
|
||||
EulerDiscreteScheduler,
|
||||
StableVideoDiffusionPipeline,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
)
|
||||
|
||||
from .convert_svd_to_diffusers import (
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_unet_diffusers_config,
|
||||
)
|
||||
|
||||
|
||||
SVD_V1_CKPT = "stabilityai/stable-video-diffusion-img2vid-xt"
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument("--config_path", default=None, type=str, required=True, help="Config filepath.")
|
||||
parser.add_argument("--dump_path", default=None, type=str)
|
||||
parser.add_argument("--push_to_hub", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
original_ckpt = safetensors.torch.load_file(args.original_ckpt_path, device="cpu")
|
||||
config = yaml.safe_load(args.config_path)
|
||||
|
||||
unet_config = create_unet_diffusers_config(config, image_size=768)
|
||||
unet = UNetSpatioTemporalConditionModel(**unet_config)
|
||||
unet_state_dict = convert_ldm_unet_checkpoint(original_ckpt, config)
|
||||
unet.load_state_dict(unet_state_dict, strict=True)
|
||||
|
||||
vae = AutoencoderKLTemporalDecoder()
|
||||
vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, config)
|
||||
vae.load_state_dict(vae_state_dict, strict=True)
|
||||
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained(SVD_V1_CKPT, subfolder="scheduler")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(SVD_V1_CKPT, subfolder="image_encoder")
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(SVD_V1_CKPT, subfolder="feature_extractor")
|
||||
|
||||
pipeline = StableVideoDiffusionPipeline(
|
||||
unet=unet, vae=vae, image_encoder=image_encoder, feature_extractor=feature_extractor, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
|
||||
@@ -12,23 +12,26 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
if controlnet:
|
||||
unet_params = original_config.model.params.control_stage_config.params
|
||||
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
|
||||
else:
|
||||
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
if (
|
||||
"unet_config" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["unet_config"] is not None
|
||||
):
|
||||
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
||||
else:
|
||||
unet_params = original_config.model.params.network_config.params
|
||||
unet_params = original_config["model"]["params"]["network_config"]["params"]
|
||||
|
||||
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["encoder_config"]["params"]
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = (
|
||||
"CrossAttnDownBlockSpatioTemporal"
|
||||
if resolution in unet_params.attention_resolutions
|
||||
if resolution in unet_params["attention_resolutions"]
|
||||
else "DownBlockSpatioTemporal"
|
||||
)
|
||||
down_block_types.append(block_type)
|
||||
@@ -39,32 +42,32 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = (
|
||||
"CrossAttnUpBlockSpatioTemporal"
|
||||
if resolution in unet_params.attention_resolutions
|
||||
if resolution in unet_params["attention_resolutions"]
|
||||
else "UpBlockSpatioTemporal"
|
||||
)
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
if unet_params.transformer_depth is not None:
|
||||
if unet_params["transformer_depth"] is not None:
|
||||
transformer_layers_per_block = (
|
||||
unet_params.transformer_depth
|
||||
if isinstance(unet_params.transformer_depth, int)
|
||||
else list(unet_params.transformer_depth)
|
||||
unet_params["transformer_depth"]
|
||||
if isinstance(unet_params["transformer_depth"], int)
|
||||
else list(unet_params["transformer_depth"])
|
||||
)
|
||||
else:
|
||||
transformer_layers_per_block = 1
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
||||
|
||||
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
||||
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
|
||||
use_linear_projection = (
|
||||
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
||||
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
|
||||
)
|
||||
if use_linear_projection:
|
||||
# stable diffusion 2-base-512 and 2-768
|
||||
if head_dim is None:
|
||||
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
|
||||
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
|
||||
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
|
||||
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
|
||||
|
||||
class_embed_type = None
|
||||
addition_embed_type = None
|
||||
@@ -72,23 +75,25 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
projection_class_embeddings_input_dim = None
|
||||
context_dim = None
|
||||
|
||||
if unet_params.context_dim is not None:
|
||||
if unet_params["context_dim"] is not None:
|
||||
context_dim = (
|
||||
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
|
||||
unet_params["context_dim"]
|
||||
if isinstance(unet_params["context_dim"], int)
|
||||
else unet_params["context_dim"][0]
|
||||
)
|
||||
|
||||
if "num_classes" in unet_params:
|
||||
if unet_params.num_classes == "sequential":
|
||||
if unet_params["num_classes"] == "sequential":
|
||||
addition_time_embed_dim = 256
|
||||
assert "adm_in_channels" in unet_params
|
||||
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
||||
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
|
||||
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params.in_channels,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"layers_per_block": unet_params.num_res_blocks,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": head_dim,
|
||||
"use_linear_projection": use_linear_projection,
|
||||
@@ -100,15 +105,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
}
|
||||
|
||||
if "disable_self_attentions" in unet_params:
|
||||
config["only_cross_attention"] = unet_params.disable_self_attentions
|
||||
config["only_cross_attention"] = unet_params["disable_self_attentions"]
|
||||
|
||||
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
|
||||
config["num_class_embeds"] = unet_params.num_classes
|
||||
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
|
||||
config["num_class_embeds"] = unet_params["num_classes"]
|
||||
|
||||
if controlnet:
|
||||
config["conditioning_channels"] = unet_params.hint_channels
|
||||
config["conditioning_channels"] = unet_params["hint_channels"]
|
||||
else:
|
||||
config["out_channels"] = unet_params.out_channels
|
||||
config["out_channels"] = unet_params["out_channels"]
|
||||
config["up_block_types"] = tuple(up_block_types)
|
||||
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user