mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
The function load_model_dict_into_meta was moved from modeling_utils.py to model_loading_utils.py but the imports in the conversion scripts were not updated, causing ImportError when running these scripts. This fixes the import in 6 conversion scripts: - scripts/convert_sd3_to_diffusers.py - scripts/convert_stable_cascade_lite.py - scripts/convert_stable_cascade.py - scripts/convert_stable_audio.py - scripts/convert_sana_to_diffusers.py - scripts/convert_sana_controlnet_to_diffusers.py Fixes #12606
227 lines
8.2 KiB
Python
227 lines
8.2 KiB
Python
# Run this script to convert the Stable Cascade lite model weights (stage C lite prior and stage B lite decoder) to diffusers pipelines.
|
|
import argparse
|
|
from contextlib import nullcontext
|
|
|
|
import torch
|
|
from safetensors.torch import load_file
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
CLIPConfig,
|
|
CLIPImageProcessor,
|
|
CLIPTextModelWithProjection,
|
|
CLIPVisionModelWithProjection,
|
|
)
|
|
|
|
from diffusers import (
|
|
DDPMWuerstchenScheduler,
|
|
StableCascadeCombinedPipeline,
|
|
StableCascadeDecoderPipeline,
|
|
StableCascadePriorPipeline,
|
|
)
|
|
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
|
from diffusers.models import StableCascadeUNet
|
|
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
|
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
|
from diffusers.utils import is_accelerate_available
|
|
|
|
|
|
if is_accelerate_available():
|
|
from accelerate import init_empty_weights
|
|
|
|
# Command-line interface. Every checkpoint path is built as "<model_path>/<stage name>",
# and each *_output_path is handed to `save_pretrained` as the output directory.
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
# Required: without it the checkpoint paths silently become "None/<name>" and the script
# only fails much later, when it tries to open them.
parser.add_argument("--model_path", type=str, required=True, help="Location of Stable Cascade weights")
parser.add_argument(
    "--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
    "--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument(
    "--use_safetensors",
    action="store_true",
    help="Load the checkpoints with safetensors instead of torch.load",
)
parser.add_argument(
    "--prior_output_path",
    default="stable-cascade-prior-lite",
    type=str,
    help="Output directory for the prior pipeline",
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder-lite",
    help="Output directory for the decoder pipeline",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined-lite",
    help="Output directory for the combined pipeline",
)
parser.add_argument(
    "--save_combined",
    action="store_true",
    help="Also save a combined prior + decoder pipeline (requires converting both stages)",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()

# Validate the stage selection up front, before any model is downloaded or converted.
# A combined pipeline needs both the prior (stage C) and the decoder (stage B).
converting_nothing = args.skip_stage_b and args.skip_stage_c
skipping_some_stage = args.skip_stage_b or args.skip_stage_c

if converting_nothing:
    raise ValueError("At least one stage should be converted")
if args.save_combined and skipping_some_stage:
    raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path

# Original checkpoints are always read onto the CPU; `dtype` is only applied to the
# assembled pipelines right before they are saved.
device = "cpu"
dtype = torch.bfloat16 if args.variant == "bf16" else torch.float32

# Full paths of the original stage C (prior) and stage B (decoder) checkpoints.
prior_checkpoint_path, decoder_checkpoint_path = (
    f"{model_path}/{args.stage_c_name}",
    f"{model_path}/{args.stage_b_name}",
)
# CLIP text encoder and tokenizer, reused by the prior, decoder and combined pipelines.
# The text model is built from the text sub-config of the bigG CLIP checkpoint after the
# top-level `projection_dim` has been copied into that sub-config.
clip_text_repo = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config = CLIPConfig.from_pretrained(clip_text_repo)
text_config = config.text_config
text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_text_repo, config=text_config)
tokenizer = AutoTokenizer.from_pretrained(clip_text_repo)

# Image-conditioning components for the prior (and combined) pipeline.
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# A single scheduler instance is shared by the prior and decoder pipelines.
scheduler = DDPMWuerstchenScheduler()

# Context manager the UNets are instantiated under: accelerate's `init_empty_weights`
# when accelerate is installed, otherwise a no-op `nullcontext`.
if is_accelerate_available():
    ctx = init_empty_weights
else:
    ctx = nullcontext
if not args.skip_stage_c:
    # Prior (stage C)
    # Read the original checkpoint. safetensors is used when requested explicitly or when
    # the file name says so (the default --stage_c_name ends in ".safetensors");
    # torch.load cannot parse the safetensors format.
    if args.use_safetensors or prior_checkpoint_path.endswith(".safetensors"):
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        # NOTE(review): depending on the installed torch version's `weights_only` default,
        # torch.load may unpickle arbitrary objects; only convert trusted checkpoints.
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    # Convert the original state dict into the diffusers StableCascadeUNet format.
    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    # Hard-coded stage C lite architecture.
    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=1536,
            block_out_channels=[1536, 1536],
            num_attention_heads=[24, 24],
            down_num_layers_per_block=[4, 12],
            up_num_layers_per_block=[12, 4],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )

    # With accelerate the model was created under init_empty_weights, so its parameters
    # are filled via load_model_dict_into_meta instead of load_state_dict.
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
if not args.skip_stage_b:
    # Decoder (stage B)
    # Read the original checkpoint. safetensors is used when requested explicitly or when
    # the file name says so (the default --stage_b_name ends in ".safetensors");
    # torch.load cannot parse the safetensors format.
    if args.use_safetensors or decoder_checkpoint_path.endswith(".safetensors"):
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        # NOTE(review): depending on the installed torch version's `weights_only` default,
        # torch.load may unpickle arbitrary objects; only convert trusted checkpoints.
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    # Convert the original state dict into the diffusers StableCascadeUNet format.
    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)

    # Hard-coded stage B lite architecture.
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 576, 1152, 1152],
            down_num_layers_per_block=[2, 4, 14, 4],
            up_num_layers_per_block=[4, 14, 4, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[2, 2, 2, 2],
            num_attention_heads=[0, 9, 18, 18],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    # With accelerate the model was created under init_empty_weights, so its parameters
    # are filled via load_model_dict_into_meta instead of load_state_dict.
    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
if args.save_combined:
    # Stable Cascade combined pipeline, assembled from the components built above.
    # The stage-skip validation earlier in the script guarantees both stages were
    # converted when --save_combined is set, so prior_model, decoder and vqmodel exist.
    decoder_components = {
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "decoder": decoder,
        "scheduler": scheduler,
        "vqgan": vqmodel,
    }
    # The prior half reuses the same text encoder, tokenizer and scheduler instances.
    prior_components = {
        "prior_text_encoder": text_encoder,
        "prior_tokenizer": tokenizer,
        "prior_prior": prior_model,
        "prior_scheduler": scheduler,
        "prior_image_encoder": image_encoder,
        "prior_feature_extractor": feature_extractor,
    }
    stable_cascade_pipeline = StableCascadeCombinedPipeline(**decoder_components, **prior_components)
    stable_cascade_pipeline.to(dtype).save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )