mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
The function `load_model_dict_into_meta` was moved from `modeling_utils.py` to `model_loading_utils.py`, but the imports in the conversion scripts were not updated, causing an `ImportError` when running these scripts. This fixes the import in 6 conversion scripts:
- scripts/convert_sd3_to_diffusers.py
- scripts/convert_stable_cascade_lite.py
- scripts/convert_stable_cascade.py
- scripts/convert_stable_audio.py
- scripts/convert_sana_to_diffusers.py
- scripts/convert_sana_controlnet_to_diffusers.py

Fixes #12606
280 lines
11 KiB
Python
280 lines
11 KiB
Python
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
|
|
import argparse
|
|
import json
|
|
import os
|
|
from contextlib import nullcontext
|
|
|
|
import torch
|
|
from safetensors.torch import load_file
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
T5EncoderModel,
|
|
)
|
|
|
|
from diffusers import (
|
|
AutoencoderOobleck,
|
|
CosineDPMSolverMultistepScheduler,
|
|
StableAudioDiTModel,
|
|
StableAudioPipeline,
|
|
StableAudioProjectionModel,
|
|
)
|
|
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
|
from diffusers.utils import is_accelerate_available
|
|
|
|
|
|
if is_accelerate_available():
|
|
from accelerate import init_empty_weights
|
|
|
|
|
|
def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
    """Split an original Stable Audio state dict into three renamed state dicts.

    Entries are routed by substring of their key: keys containing
    ``conditioner.conditioners`` feed the projection model, keys containing
    ``model.model.`` feed the DiT transformer and keys containing
    ``pretransform.model.`` feed the autoencoder. Each group is then renamed
    through ordered substring replacements (the order matters, since later
    replacements see the output of earlier ones).

    Args:
        state_dict: Mapping from original parameter names to tensors.
        num_autoencoder_layers: Autoencoder layers whose index equals
            ``num_autoencoder_layers + 1`` are renamed to ``snake1`` and those
            whose index equals ``num_autoencoder_layers + 2`` to ``conv2``.

    Returns:
        The tuple ``(model_state_dict, projection_model_state_dict, autoencoder_state_dict)``.

    Raises:
        ValueError: If an autoencoder key is renamed to a name that is already present.
    """
    # ---------------------------------------------------------------- projection model
    projection_model_state_dict = {}
    for name, tensor in state_dict.items():
        if "conditioner.conditioners" not in name:
            continue
        short_name = name.replace("conditioner.conditioners.", "")
        short_name = short_name.replace("embedder.embedding", "time_positional_embedding")
        projection_model_state_dict[short_name] = tensor

    # NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should
    # be adapted a bit if there is.
    for old_name in list(projection_model_state_dict):
        renamed = old_name.replace("seconds_start", "start_number_conditioner")
        renamed = renamed.replace("seconds_total", "end_number_conditioner")
        projection_model_state_dict[renamed] = projection_model_state_dict.pop(old_name)

    # ---------------------------------------------------------------- DiT transformer
    # Applied strictly in this order.
    dit_renames = (
        # attention layers
        ("transformer.", ""),
        ("layers", "transformer_blocks"),
        ("self_attn", "attn1"),
        ("cross_attn", "attn2"),
        ("ff.ff", "ff.net"),
        ("pre_norm", "norm1"),
        ("cross_attend_norm", "norm2"),
        ("ff_norm", "norm3"),
        ("to_out", "to_out.0"),
        # layernorm parameters
        ("gamma", "weight"),
        ("beta", "bias"),
        # other layers
        ("project", "proj"),
        ("to_timestep_embed", "timestep_proj"),
        ("timestep_features", "time_proj"),
        ("to_global_embed", "global_proj"),
        ("to_cond_embed", "cross_attention_proj"),
    )

    model_state_dict = {}
    for name, tensor in state_dict.items():
        if "model.model." in name:
            model_state_dict[name.replace("model.model.", "")] = tensor

    for old_name in list(model_state_dict):
        renamed = old_name
        for old, new in dit_renames:
            renamed = renamed.replace(old, new)

        tensor = model_state_dict.pop(old_name)

        # we're using diffusers implementation of time_proj (GaussianFourierProjection) which creates a 1D tensor
        if renamed == "time_proj.weight":
            tensor = tensor.squeeze(1)

        if "to_qkv" in renamed:
            # Fused query/key/value weights are split into three equal chunks along dim 0.
            query, key_part, value_part = torch.chunk(tensor, 3, dim=0)
            model_state_dict[renamed.replace("qkv", "q")] = query
            model_state_dict[renamed.replace("qkv", "k")] = key_part
            model_state_dict[renamed.replace("qkv", "v")] = value_part
        elif "to_kv" in renamed:
            # Fused key/value weights are split into two equal chunks along dim 0.
            key_part, value_part = torch.chunk(tensor, 2, dim=0)
            model_state_dict[renamed.replace("kv", "k")] = key_part
            model_state_dict[renamed.replace("kv", "v")] = value_part
        else:
            model_state_dict[renamed] = tensor

    # ---------------------------------------------------------------- autoencoder
    autoencoder_state_dict = {}
    for name, tensor in state_dict.items():
        if "pretransform.model." in name:
            short_name = name.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1")
            autoencoder_state_dict[short_name] = tensor

    # (sub-layer position inside a block, new sub-module name), applied in this order.
    encoder_sublayers = ((0, "res_unit1"), (1, "res_unit2"), (2, "res_unit3"), (3, "snake1"), (4, "conv1"))
    decoder_sublayers = ((2, "res_unit1"), (3, "res_unit2"), (4, "res_unit3"), (0, "snake1"), (1, "conv_t1"))
    inner_renames = (
        ("layers.0.beta", "snake1.beta"),
        ("layers.0.alpha", "snake1.alpha"),
        ("layers.2.beta", "snake2.beta"),
        ("layers.2.alpha", "snake2.alpha"),
        ("layers.1.bias", "conv1.bias"),
        ("layers.1.weight_", "conv1.weight_"),
        ("layers.3.bias", "conv2.bias"),
        ("layers.3.weight_", "conv2.weight_"),
    )

    for old_name in list(autoencoder_state_dict):
        renamed = old_name
        if "coder.layers" in renamed:
            # get idx of the layer
            idx = int(renamed.split("coder.layers.")[1].split(".")[0])
            block = f"block.{idx - 1}"

            renamed = renamed.replace(f"coder.layers.{idx}", f"coder.{block}")

            sublayers = encoder_sublayers if "encoder" in renamed else decoder_sublayers
            for position, target in sublayers:
                renamed = renamed.replace(f"{block}.layers.{position}", f"{block}.{target}")

            for old, new in inner_renames:
                renamed = renamed.replace(old, new)

            # The two layers after the numbered blocks become the final snake activation and convolution.
            if idx == num_autoencoder_layers + 1:
                renamed = renamed.replace(block, "snake1")
            elif idx == num_autoencoder_layers + 2:
                renamed = renamed.replace(block, "conv2")

        tensor = autoencoder_state_dict.pop(old_name)
        if "snake" in renamed:
            # Snake parameters get a leading and a trailing singleton dimension.
            tensor = tensor.unsqueeze(0).unsqueeze(-1)
        if renamed in autoencoder_state_dict:
            raise ValueError(f"{renamed} already in state dict.")
        autoencoder_state_dict[renamed] = tensor

    return model_state_dict, projection_model_state_dict, autoencoder_state_dict
|
|
|
|
|
|
# Command-line interface. Arguments are parsed at module level, so they are read as soon as the script runs.
parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument(
    "--model_folder_path",
    type=str,
    # Required: the checkpoint and config paths are built from this folder with os.path.join, which fails with
    # an opaque TypeError when the value is None.
    required=True,
    help="Location of Stable Audio weights and config",
)
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--save_directory",
    type=str,
    default="./tmp/stable-audio-1.0",
    help="Directory to save a pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
    "--repo_id",
    type=str,
    default="stable-audio-1.0",
    help="Hub organization to save the pipelines to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()
|
|
|
|
# The checkpoint file name depends on the serialization format selected with --use_safetensors.
checkpoint_path = os.path.join(
    args.model_folder_path,
    "model.safetensors" if args.use_safetensors else "model.ckpt",
)
config_path = os.path.join(args.model_folder_path, "model_config.json")

# Everything is loaded on CPU; only `--variant bf16` selects bfloat16, any other value keeps float32.
device = "cpu"
dtype = torch.bfloat16 if args.variant == "bf16" else torch.float32
|
|
|
|
# Original model configuration stored next to the checkpoint.
# Decode explicitly as UTF-8 instead of relying on the platform's locale default encoding.
with open(config_path, encoding="utf-8") as f_in:
    config_dict = json.load(f_in)

# Index the conditioner configs by their "id" so they can be looked up by name below.
conditioning_dict = {
    conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}

# The conditioner registered under "prompt" holds the T5 model name and maximum token length.
t5_model_config = conditioning_dict["prompt"]

# T5 Text encoder and its tokenizer, both loaded from the model name given in the original config.
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
    t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)
|
|
|
|
|
|
# Scheduler, built from fixed hyper-parameters (nothing here is read from the original config).
scheduler = CosineDPMSolverMultistepScheduler(
    sigma_min=0.3,
    sigma_max=500,
    sigma_data=1.0,
    sigma_schedule="exponential",
    solver_order=2,
    prediction_type="v_prediction",
)

# Context manager factory used when instantiating the models: accelerate's `init_empty_weights` when
# accelerate is installed, otherwise a no-op context.
if is_accelerate_available():
    ctx = init_empty_weights
else:
    ctx = nullcontext
|
|
|
|
|
|
# Load the original checkpoint onto `device`, using the loader that matches the format chosen on the CLI.
if args.use_safetensors:
    orig_state_dict = load_file(checkpoint_path, device=device)
else:
    # SECURITY NOTE(review): torch.load unpickles the file. Depending on the installed torch version's
    # `weights_only` default this can execute arbitrary code, so only convert checkpoints from a trusted source.
    # NOTE(review): the result is passed to the converter as-is; if the .ckpt nests its weights under a
    # top-level key (e.g. "state_dict") it would need unwrapping first -- TODO confirm against a real checkpoint.
    orig_state_dict = torch.load(checkpoint_path, map_location=device)

# Configuration of the diffusion model, read from the original config file.
model_config = config_dict["model"]["diffusion"]["config"]

# Split and rename the original weights into the transformer, projection model and autoencoder state dicts.
model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
    orig_state_dict
)
|
|
|
|
|
|
# Projection model: instantiated under `ctx`, then filled with the converted conditioner weights.
with ctx():
    projection_model = StableAudioProjectionModel(
        text_encoder_dim=text_encoder.config.d_model,
        conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
        # assume `seconds_start` and `seconds_total` have the same min / max values.
        min_value=conditioning_dict["seconds_start"]["min_val"],
        max_value=conditioning_dict["seconds_start"]["max_val"],
    )

# Without accelerate the module holds regular parameters and the plain PyTorch loader is used.
if not is_accelerate_available():
    projection_model.load_state_dict(projection_model_state_dict)
else:
    load_model_dict_into_meta(projection_model, projection_model_state_dict)
|
|
|
|
# Diffusion transformer (DiT): instantiated under `ctx`, then filled with the converted weights.
attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
    model = StableAudioDiTModel(
        # Floor division keeps `sample_size` an integer; true division (`/`) would pass a float such as 1024.0.
        sample_size=int(config_dict["sample_size"])
        // int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
        in_channels=model_config["io_channels"],
        num_layers=model_config["depth"],
        attention_head_dim=attention_head_dim,
        num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
        num_attention_heads=model_config["num_heads"],
        out_channels=model_config["io_channels"],
        cross_attention_dim=model_config["cond_token_dim"],
        time_proj_dim=256,
        global_states_input_dim=model_config["global_cond_dim"],
        cross_attention_input_dim=model_config["cond_token_dim"],
    )

# With accelerate the converted tensors are loaded through the diffusers helper, otherwise through plain PyTorch.
if is_accelerate_available():
    load_model_dict_into_meta(model, model_state_dict)
else:
    model.load_state_dict(model_state_dict)
|
|
|
|
|
|
# Autoencoder: its hyper-parameters come from the "pretransform" section of the original config.
autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
    autoencoder = AutoencoderOobleck(
        encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
        downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
        decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
        decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
        audio_channels=autoencoder_config["io_channels"],
        channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
        sampling_rate=config_dict["sample_rate"],
    )

# Without accelerate the module holds regular parameters and the plain PyTorch loader is used.
if not is_accelerate_available():
    autoencoder.load_state_dict(autoencoder_state_dict)
else:
    load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
|
|
|
|
|
|
# Stable Audio pipeline: assemble the converted components, cast to the target dtype and save
# (optionally pushing to the Hub).
pipeline = StableAudioPipeline(
    vae=autoencoder,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    projection_model=projection_model,
    transformer=model,
    scheduler=scheduler,
)
pipeline.to(dtype).save_pretrained(
    args.save_directory,
    repo_id=args.repo_id,
    push_to_hub=args.push_to_hub,
    variant=args.variant,
)
|