Compare commits

...

18 Commits

Author SHA1 Message Date
Dhruv Nair
046ae26d1e Merge branch 'convert-cascade-lite' of https://github.com/huggingface/diffusers into convert-cascade-lite 2024-03-13 06:37:31 +00:00
Dhruv Nair
6308b16aba update 2024-03-13 06:37:16 +00:00
Sayak Paul
1b00c4e554 Merge branch 'main' into convert-cascade-lite 2024-03-13 11:35:20 +05:30
Dhruv Nair
4b48de0c41 update 2024-03-13 05:35:21 +00:00
Dhruv Nair
832071b468 update 2024-03-13 05:23:54 +00:00
Dhruv Nair
12390a204c Merge branch 'convert-cascade-lite' of https://github.com/huggingface/diffusers into convert-cascade-lite 2024-03-13 04:15:35 +00:00
Dhruv Nair
6732aa3a28 Merge branch 'main' into convert-cascade-lite 2024-03-13 04:10:07 +00:00
Sayak Paul
aad9d41c54 Merge branch 'main' into convert-cascade-lite 2024-03-13 08:37:41 +05:30
Sayak Paul
8e060da4ef Merge branch 'main' into convert-cascade-lite 2024-03-13 08:14:49 +05:30
Sayak Paul
86daa3518d Merge branch 'main' into convert-cascade-lite 2024-03-12 21:15:42 +05:30
Dhruv Nair
b82edd9404 update 2024-03-11 14:29:20 +00:00
Dhruv Nair
7119e087ba Merge branch 'convert-cascade-lite' of https://github.com/huggingface/diffusers into convert-cascade-lite 2024-03-11 13:41:42 +00:00
Dhruv Nair
962ceec438 update 2024-03-11 13:20:31 +00:00
Sayak Paul
4a00ef579a Merge branch 'main' into convert-cascade-lite 2024-03-11 17:12:55 +05:30
Dhruv Nair
5cc31912a2 update 2024-03-11 07:19:20 +00:00
Dhruv Nair
1f4120eed1 update 2024-03-11 03:23:44 +00:00
Dhruv Nair
068abab096 update 2024-03-11 03:20:21 +00:00
Dhruv Nair
872799e029 update 2024-03-08 14:46:29 +00:00
2 changed files with 389 additions and 160 deletions

View File

@@ -1,7 +1,7 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
@@ -18,23 +18,56 @@ from diffusers import (
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument(
"--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
@@ -52,164 +85,134 @@ tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b1
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
load_model_dict_into_meta(decoder, state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
if not args.skip_stage_c:
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if not args.skip_stage_b:
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)

View File

@@ -0,0 +1,226 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
from diffusers import (
DDPMWuerstchenScheduler,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument(
"--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
"--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--prior_output_path",
default="stable-cascade-prior-lite",
type=str,
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--decoder_output_path",
type=str,
default="stable-cascade-decoder-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument(
"--combined_output_path",
type=str,
default="stable-cascade-combined-lite",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
if args.skip_stage_b and args.skip_stage_c:
raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
raise ValueError("Cannot skip stages when creating a combined pipeline")
model_path = args.model_path
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if not args.skip_stage_c:
# Prior
if args.use_safetensors:
prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)
with ctx():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=1536,
block_out_channels=[1536, 1536],
num_attention_heads=[24, 24],
down_num_layers_per_block=[4, 12],
up_num_layers_per_block=[12, 4],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
if is_accelerate_available():
load_model_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.to(dtype).save_pretrained(
args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if not args.skip_stage_b:
# Decoder
if args.use_safetensors:
decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
with ctx():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 576, 1152, 1152],
down_num_layers_per_block=[2, 4, 14, 4],
up_num_layers_per_block=[4, 14, 4, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[2, 2, 2, 2],
num_attention_heads=[0, 9, 18, 18],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
if is_accelerate_available():
load_model_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.to(dtype).save_pretrained(
args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)
if args.save_combined:
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.to(dtype).save_pretrained(
args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
)