mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
* Add VAE slicing and tiling methods. * Switch to using VaeImageProcessing for preprocessing and postprocessing of images. * Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor). * Remove the postprocess() function because we're using a VaeImageProcessor instead. * Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead. * Refactor generating text from text latents into a decode_text_latents method. * Add enable_full_determinism() to UniDiffuser tests. * make style * Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests. * Remove enable_model_cpu_offload since it is now part of DiffusionPipeline. * Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash. * Update UniDiffuser conversion script. * Make safe_serialization configurable in UniDiffuser conversion script. * Rename image_processor to clip_image_processor in UniDiffuser tests. * Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests. * Add initial test for compiling the UniDiffuser model (not tested yet). * Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline. * Turn off standard classifier-free guidance for now. * make style * make fix-copies * apply suggestions from review --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
787 lines
31 KiB
Python
787 lines
31 KiB
Python
# Convert the original UniDiffuser checkpoints into diffusers equivalents.
|
|
|
|
import argparse
|
|
from argparse import Namespace
|
|
|
|
import torch
|
|
from transformers import (
|
|
CLIPImageProcessor,
|
|
CLIPTextConfig,
|
|
CLIPTextModel,
|
|
CLIPTokenizer,
|
|
CLIPVisionConfig,
|
|
CLIPVisionModelWithProjection,
|
|
GPT2Tokenizer,
|
|
)
|
|
|
|
from diffusers import (
|
|
AutoencoderKL,
|
|
DPMSolverMultistepScheduler,
|
|
UniDiffuserModel,
|
|
UniDiffuserPipeline,
|
|
UniDiffuserTextDecoder,
|
|
)
|
|
|
|
|
|
# Scheduler settings shared by the "test" and "big" pipelines. They are kept in a Namespace so the __main__ block can
# read them as attributes (SCHEDULER_CONFIG.beta_start, ...).
SCHEDULER_CONFIG = Namespace(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    solver_order=3,
)
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from `path`.

    A non-negative `n_shave_prefix_segments` removes that many leading segments. A negative value removes
    `abs(n_shave_prefix_segments)` trailing segments. The remaining segments are re-joined with ".".
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Build local renaming entries for VAE resnet weights.

    Each key in `old_list` has "nin_shortcut" renamed to "conv_shortcut" and is then passed through
    `shave_segments` with `n_shave_prefix_segments`. The result is a list of `{"old": key, "new": renamed_key}`
    dicts, in the same order as `old_list`.
    """

    def _rename(key):
        renamed = key.replace("nin_shortcut", "conv_shortcut")
        return shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": key, "new": _rename(key)} for key in old_list]
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Build local renaming entries for VAE attention weights.

    Renames are applied to both weights and biases:
    - "norm" becomes "group_norm"
    - "q"/"k"/"v" become "to_q"/"to_k"/"to_v"
    - "proj_out" becomes "to_out.0"

    The result then goes through `shave_segments`. Returns a list of `{"old": key, "new": renamed_key}` dicts, in
    the same order as `old_list`.
    """
    # Substring replacements, applied one after another in exactly this order.
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )

    mapping = []
    for old_key in old_list:
        new_key = old_key
        for before, after in replacements:
            new_key = new_key.replace(before, after)
        new_key = shave_segments(new_key, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_key, "new": new_key})
    return mapping
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
    """
    Squeeze convolution-shaped attention weights in `checkpoint` to linear-layer shape, in place.

    - Keys whose last two segments are "query.weight", "key.weight" or "value.weight" and whose value has more
      than two dims are replaced by `value[:, :, 0, 0]`.
    - Any other key containing "proj_attn.weight" with more than two dims is replaced by `value[:, :, 0]`.

    Returns None.
    """
    full = slice(None)
    qkv_suffixes = {"query.weight", "key.weight", "value.weight"}
    # Walk a snapshot of the key list while entries are rewritten in place.
    for name in list(checkpoint.keys()):
        if ".".join(name.split(".")[-2:]) in qkv_suffixes:
            index = (full, full, 0, 0)
        elif "proj_attn.weight" in name:
            index = (full, full, 0)
        else:
            continue
        weight = checkpoint[name]
        if weight.ndim > 2:
            checkpoint[name] = weight[index]
|
|
|
|
|
|
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
    paths,
    checkpoint,
    old_checkpoint,
    attention_paths_to_split=None,
    additional_replacements=None,
    num_head_channels=1,
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths (`list[dict]`): Entries of the form `{"old": <key in old_checkpoint>, "new": <locally renamed key>}`.
        checkpoint (`dict`): Destination state dict; written in place.
        old_checkpoint (`dict`): Source state dict; only read.
        attention_paths_to_split (`dict`, *optional*): Maps a fused-qkv key of `old_checkpoint` to a dict with
            "query"/"key"/"value" destination keys. The callers visible in this file never pass it.
        additional_replacements (`list[dict]`, *optional*): Extra `{"old": ..., "new": ...}` substring replacements
            applied to every new key after the built-in `middle_block` renames.
        num_head_channels (`int`, *optional*, defaults to 1): Channels per attention head. Only used when splitting
            attention weights.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            # NOTE: `(-1)` is the int -1, not a tuple, so non-3-D tensors are flattened.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // num_head_channels // 3

            # Group rows per head, then cut each head's rows into equal q/k/v chunks along dim 1.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        # NOTE(review): attention_paths_to_split is keyed by old-checkpoint keys (see the split loop above), but this
        # compares the *new* path against them. Confirm this is intended.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        # Attention tensors stored as conv kernels keep only the first spatial index:
        # 3-D -> [:, :, 0], 4-D -> [:, :, 0, 0]. Everything else is copied unchanged.
        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
        shape = old_checkpoint[path["old"]].shape
        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif is_attn_weight and len(shape) == 4:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
|
|
|
|
|
|
def create_vae_diffusers_config(config_type):
    """
    Return the hardcoded AutoencoderKL config for `config_type`.

    Args:
        config_type (`str`): "test" for the small model used by the fast tests, or "big" for UniDiffuser-v1.

    Returns:
        `dict`: Keyword arguments for `AutoencoderKL`.

    Raises:
        NotImplementedError: If `config_type` is neither "test" nor "big".
    """
    # Hardcoded for now. Dispatch on the argument, not on the script-level `args`, so the function also works when
    # imported from another module.
    if config_type == "test":
        vae_config = create_vae_diffusers_config_test()
    elif config_type == "big":
        vae_config = create_vae_diffusers_config_big()
    else:
        raise NotImplementedError(
            f"Config type {config_type} is not implemented, currently only config types"
            " 'test' and 'big' are available."
        )
    return vae_config
|
|
|
|
|
|
def create_unidiffuser_unet_config(config_type, version):
    """
    Return the hardcoded UniDiffuserModel (U-ViT) config for `config_type` and model `version`.

    Args:
        config_type (`str`): "test" for the small model used by the fast tests, or "big" for the full model.
        version (`int`): UniDiffuser version. Version 1 enables the data type embeddings.

    Returns:
        `dict`: Keyword arguments for `UniDiffuserModel`.

    Raises:
        NotImplementedError: If `config_type` is neither "test" nor "big".
    """
    # Hardcoded for now. Dispatch on the argument, not on the script-level `args`, so the function also works when
    # imported from another module.
    if config_type == "test":
        unet_config = create_unidiffuser_unet_config_test()
    elif config_type == "big":
        unet_config = create_unidiffuser_unet_config_big()
    else:
        raise NotImplementedError(
            f"Config type {config_type} is not implemented, currently only config types"
            " 'test' and 'big' are available."
        )
    # Unidiffuser-v1 uses data type embeddings
    if version == 1:
        unet_config["use_data_type_embedding"] = True
    return unet_config
|
|
|
|
|
|
def create_text_decoder_config(config_type):
    """
    Return the hardcoded UniDiffuserTextDecoder (caption decoder) config for `config_type`.

    Args:
        config_type (`str`): "test" for the small model used by the fast tests, or "big" for the GPT-2 based model.

    Returns:
        `dict`: Keyword arguments for `UniDiffuserTextDecoder`.

    Raises:
        NotImplementedError: If `config_type` is neither "test" nor "big".
    """
    # Hardcoded for now. Dispatch on the argument, not on the script-level `args`, so the function also works when
    # imported from another module.
    if config_type == "test":
        text_decoder_config = create_text_decoder_config_test()
    elif config_type == "big":
        text_decoder_config = create_text_decoder_config_big()
    else:
        raise NotImplementedError(
            f"Config type {config_type} is not implemented, currently only config types"
            " 'test' and 'big' are available."
        )
    return text_decoder_config
|
|
|
|
|
|
# Hardcoded configs for test versions of the UniDiffuser models, corresponding to those in the fast default tests.
def create_vae_diffusers_config_test():
    """Return a fresh kwargs dict for the small two-block AutoencoderKL used by the fast tests."""
    return dict(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"] * 2,
        up_block_types=["UpDecoderBlock2D"] * 2,
        block_out_channels=[32, 64],
        latent_channels=4,
        layers_per_block=1,
    )
|
|
|
|
|
|
def create_unidiffuser_unet_config_test():
    """Return a fresh kwargs dict for the tiny two-layer UniDiffuserModel used by the fast tests."""
    return dict(
        # Input/output modality sizes.
        text_dim=32,
        clip_img_dim=32,
        num_text_tokens=77,
        # Transformer shape.
        num_attention_heads=2,
        attention_head_dim=8,
        in_channels=4,
        out_channels=4,
        num_layers=2,
        dropout=0.0,
        norm_num_groups=32,
        attention_bias=False,
        sample_size=16,
        patch_size=2,
        activation_fn="gelu",
        num_embeds_ada_norm=1000,
        norm_type="layer_norm",
        block_type="unidiffuser",
        pre_layer_norm=False,
        use_timestep_embedding=False,
        norm_elementwise_affine=True,
        use_patch_pos_embed=False,
        ff_final_dropout=True,
        use_data_type_embedding=False,
    )
|
|
|
|
|
|
def create_text_decoder_config_test():
    """Return a fresh kwargs dict for the tiny caption decoder used by the fast tests."""
    return dict(
        prefix_length=77,
        prefix_inner_dim=32,
        prefix_hidden_dim=32,
        vocab_size=1024 + 1,  # 1024 + 1 for new EOS token
        n_positions=1024,
        n_embd=32,
        n_layer=5,
        n_head=4,
        n_inner=37,
        activation_function="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    )
|
|
|
|
|
|
# Hardcoded configs for the UniDiffuser V1 model at https://huggingface.co/thu-ml/unidiffuser-v1
# See also https://github.com/thu-ml/unidiffuser/blob/main/configs/sample_unidiffuser_v1.py
def create_vae_diffusers_config_big():
    """Return a fresh kwargs dict for the four-block AutoencoderKL of the full UniDiffuser model."""
    return dict(
        sample_size=256,
        in_channels=3,
        out_channels=3,
        down_block_types=["DownEncoderBlock2D"] * 4,
        up_block_types=["UpDecoderBlock2D"] * 4,
        block_out_channels=[128, 256, 512, 512],
        latent_channels=4,
        layers_per_block=2,
    )
|
|
|
|
|
|
def create_unidiffuser_unet_config_big():
    """Return a fresh kwargs dict for the 30-layer UniDiffuserModel (U-ViT) of the full UniDiffuser model."""
    return dict(
        # Input/output modality sizes.
        text_dim=64,
        clip_img_dim=512,
        num_text_tokens=77,
        # Transformer shape.
        num_attention_heads=24,
        attention_head_dim=64,
        in_channels=4,
        out_channels=4,
        num_layers=30,
        dropout=0.0,
        norm_num_groups=32,
        attention_bias=False,
        sample_size=64,
        patch_size=2,
        activation_fn="gelu",
        num_embeds_ada_norm=1000,
        norm_type="layer_norm",
        block_type="unidiffuser",
        pre_layer_norm=False,
        use_timestep_embedding=False,
        norm_elementwise_affine=True,
        use_patch_pos_embed=False,
        ff_final_dropout=True,
        use_data_type_embedding=False,
    )
|
|
|
|
|
|
# From https://huggingface.co/gpt2/blob/main/config.json, the GPT2 checkpoint used by UniDiffuser
def create_text_decoder_config_big():
    """Return a fresh kwargs dict for the GPT-2 based caption decoder of the full UniDiffuser model."""
    return dict(
        prefix_length=77,
        prefix_inner_dim=768,
        prefix_hidden_dim=64,
        vocab_size=50257 + 1,  # 50257 + 1 for new EOS token
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=3072,
        activation_function="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    )
|
|
|
|
|
|
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
    """
    Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.

    Args:
        ckpt (`str`): Path to the autoencoder_kl.pth file, loaded with `torch.load` onto the CPU.
        diffusers_model (`AutoencoderKL`): Target model; the converted state dict is loaded into it in place.
        num_head_channels (`int`, *optional*, defaults to 1): Forwarded to `assign_to_checkpoint`. The calls below
            pass no `attention_paths_to_split`, so it has no effect on the VAE conversion.

    Returns:
        `diffusers_model` after `load_state_dict` has been called with the converted weights.
    """
    # autoencoder_kl.pth ckpt is a torch state dict
    # NOTE(review): torch.load unpickles the file; only convert checkpoints from trusted sources.
    vae_state_dict = torch.load(ckpt, map_location="cpu")

    new_checkpoint = {}

    # Top-level encoder layers; "norm_out" is renamed to "conv_norm_out".
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Top-level decoder layers; same renaming as the encoder.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    # Quantization convolutions keep their names.
    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    # (the block count is the number of distinct "encoder.down.<i>" prefixes).
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        # Keys of down block i, excluding its downsampler.
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        # The downsampler conv is copied directly; `pop` also removes it from vae_state_dict.
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        # "down.<i>.block" -> "down_blocks.<i>.resnets"
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            num_head_channels=num_head_channels,  # not used in vae
        )

    # Encoder mid block: "mid.block_1"/"mid.block_2" -> "mid_block.resnets.0"/"mid_block.resnets.1".
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            num_head_channels=num_head_channels,  # not used in vae
        )

    # Encoder mid-block attention: "mid.attn_1" -> "mid_block.attentions.0" (plus the local renames from
    # renew_vae_attention_paths), then squeeze conv-shaped attention weights.
    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        vae_state_dict,
        additional_replacements=[meta_path],
        num_head_channels=num_head_channels,  # not used in vae
    )
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: checkpoint block "up.<block_id>" is written to diffusers "up_blocks.<i>" with
    # block_id = num_up_blocks - 1 - i, i.e. the block order is reversed.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        # Unlike the encoder downsamplers, the upsampler conv is read without popping it.
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            num_head_channels=num_head_channels,  # not used in vae
        )

    # Decoder mid block: same mapping as the encoder mid block.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            num_head_channels=num_head_channels,  # not used in vae
        )

    # Decoder mid-block attention, then squeeze conv-shaped attention weights again.
    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        vae_state_dict,
        additional_replacements=[meta_path],
        num_head_channels=num_head_channels,  # not used in vae
    )
    conv_attn_to_linear(new_checkpoint)

    # NOTE(review): load_state_dict is called with its default `strict` argument. Under PyTorch's default
    # (strict=True) a key mismatch raises before these prints run. Confirm this is intended.
    missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_checkpoint)
    for missing_key in missing_keys:
        print(f"Missing key: {missing_key}")
    for unexpected_key in unexpected_keys:
        print(f"Unexpected key: {unexpected_key}")

    return diffusers_model
|
|
|
|
|
|
def convert_uvit_block_to_diffusers_block(
    uvit_state_dict,
    new_state_dict,
    block_prefix,
    new_prefix="transformer.transformer_",
    skip_connection=False,
):
    """
    Copy one UniDiffuser transformer block (`Block`) into `new_state_dict`, using the key layout of a diffusers
    transformer block (`UTransformerBlock`/`UniDiffuserBlock`).

    Args:
        uvit_state_dict (`dict`): Source state dict; entries are read under `block_prefix`.
        new_state_dict (`dict`): Destination dict; updated in place.
        block_prefix (`str`): Prefix of the block in the source checkpoint, e.g. "in_blocks.0" or "mid_block".
        new_prefix (`str`): Prepended to `block_prefix` to form the destination prefix.
        skip_connection (`bool`): If True, the block's `skip_linear` and `norm1` weights go under ".skip", and all
            other weights are nested under ".block".

    Returns:
        The `(uvit_state_dict, new_state_dict)` pair that was passed in.
    """
    src = block_prefix
    dst = new_prefix + block_prefix

    if skip_connection:
        skip_renames = (
            (".skip.skip_linear.weight", ".skip_linear.weight"),
            (".skip.skip_linear.bias", ".skip_linear.bias"),
            (".skip.norm.weight", ".norm1.weight"),
            (".skip.norm.bias", ".norm1.bias"),
        )
        for dst_suffix, src_suffix in skip_renames:
            new_state_dict[dst + dst_suffix] = uvit_state_dict[src + src_suffix]
        # Out blocks nest the transformer block itself one level deeper.
        dst = dst + ".block"

    # The fused qkv projection is split along dim 0 into three equal slices: q, k, v.
    fused_qkv = uvit_state_dict[src + ".attn.qkv.weight"]
    part = fused_qkv.shape[0] // 3
    for idx, proj_name in enumerate(("to_q", "to_k", "to_v")):
        new_state_dict[f"{dst}.attn1.{proj_name}.weight"] = fused_qkv[idx * part : (idx + 1) * part]

    block_renames = (
        (".attn1.to_out.0.weight", ".attn.proj.weight"),
        (".attn1.to_out.0.bias", ".attn.proj.bias"),
        (".norm1.weight", ".norm2.weight"),
        (".norm1.bias", ".norm2.bias"),
        (".ff.net.0.proj.weight", ".mlp.fc1.weight"),
        (".ff.net.0.proj.bias", ".mlp.fc1.bias"),
        (".ff.net.2.weight", ".mlp.fc2.weight"),
        (".ff.net.2.bias", ".mlp.fc2.bias"),
        (".norm3.weight", ".norm3.weight"),
        (".norm3.bias", ".norm3.bias"),
    )
    for dst_suffix, src_suffix in block_renames:
        new_state_dict[dst + dst_suffix] = uvit_state_dict[src + src_suffix]

    return uvit_state_dict, new_state_dict
|
|
|
|
|
|
def convert_uvit_to_diffusers(ckpt, diffusers_model):
    """
    Converts a UniDiffuser uvit_v*.pth checkpoint to a diffusers UniDiffuserModel.

    Args:
        ckpt (`str`): Path to the uvit_v*.pth file, loaded with `torch.load` onto the CPU.
        diffusers_model (`UniDiffuserModel`): Target model; the converted state dict is loaded into it in place.
            Its `use_data_type_embedding` attribute decides whether the UniDiffuser-v1 data type embeddings are
            copied.

    Returns:
        `diffusers_model` after `load_state_dict` has been called with the converted weights.
    """
    # uvit_v*.pth ckpt is a torch state dict.
    # NOTE: torch.load unpickles the file, so only convert checkpoints from trusted sources.
    uvit_state_dict = torch.load(ckpt, map_location="cpu")

    new_state_dict = {}

    # Input layers
    new_state_dict["vae_img_in.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"]
    new_state_dict["vae_img_in.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"]
    new_state_dict["clip_img_in.weight"] = uvit_state_dict["clip_img_embed.weight"]
    new_state_dict["clip_img_in.bias"] = uvit_state_dict["clip_img_embed.bias"]
    new_state_dict["text_in.weight"] = uvit_state_dict["text_embed.weight"]
    new_state_dict["text_in.bias"] = uvit_state_dict["text_embed.bias"]

    new_state_dict["pos_embed"] = uvit_state_dict["pos_embed"]

    # Handle data type token embeddings for UniDiffuser-v1
    has_data_type_embedding = "token_embedding.weight" in uvit_state_dict
    if has_data_type_embedding and diffusers_model.use_data_type_embedding:
        new_state_dict["data_type_pos_embed_token"] = uvit_state_dict["pos_embed_token"]
        new_state_dict["data_type_token_embedding.weight"] = uvit_state_dict["token_embedding.weight"]
    elif has_data_type_embedding:
        # Without this branch the weights are silently dropped, e.g. when a UniDiffuser-v1 checkpoint is converted
        # with the default --version 0. Warn so the mismatch is visible.
        print(
            "Warning: the checkpoint contains data type embeddings ('token_embedding.weight', 'pos_embed_token') but"
            " the target model has use_data_type_embedding=False, so they are not converted. Use a UniDiffuser-v1"
            " config (e.g. --version 1) to keep them."
        )

    # Also initialize the PatchEmbedding in UTransformer2DModel with the PatchEmbedding from the checkpoint.
    # This isn't used in the current implementation, so might want to remove.
    new_state_dict["transformer.pos_embed.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"]
    new_state_dict["transformer.pos_embed.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"]

    # Output layers
    new_state_dict["transformer.norm_out.weight"] = uvit_state_dict["norm.weight"]
    new_state_dict["transformer.norm_out.bias"] = uvit_state_dict["norm.bias"]

    new_state_dict["vae_img_out.weight"] = uvit_state_dict["decoder_pred.weight"]
    new_state_dict["vae_img_out.bias"] = uvit_state_dict["decoder_pred.bias"]
    new_state_dict["clip_img_out.weight"] = uvit_state_dict["clip_img_out.weight"]
    new_state_dict["clip_img_out.bias"] = uvit_state_dict["clip_img_out.bias"]
    new_state_dict["text_out.weight"] = uvit_state_dict["text_out.weight"]
    new_state_dict["text_out.bias"] = uvit_state_dict["text_out.bias"]

    # in_blocks: one prefix per block, e.g. "in_blocks.0". Sorted only for a deterministic insertion order;
    # load_state_dict matches entries by key.
    in_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "in_blocks" in layer}
    for in_block_prefix in sorted(in_blocks_prefixes):
        convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, in_block_prefix)

    # mid_block
    # Assume there's only one mid block
    convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, "mid_block")

    # out_blocks carry an additional skip connection.
    out_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "out_blocks" in layer}
    for out_block_prefix in sorted(out_blocks_prefixes):
        convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, out_block_prefix, skip_connection=True)

    missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict)
    for missing_key in missing_keys:
        print(f"Missing key: {missing_key}")
    for unexpected_key in unexpected_keys:
        print(f"Unexpected key: {unexpected_key}")

    return diffusers_model
|
|
|
|
|
|
def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
    """
    Converts a UniDiffuser caption_decoder.pth checkpoint to a diffusers UniDiffuserTextDecoder.

    Args:
        ckpt (`str`): Path to the caption_decoder.pth file, loaded with `torch.load` onto the CPU.
        diffusers_model (`UniDiffuserTextDecoder`): Target model; the converted state dict is loaded into it in place.

    Returns:
        `diffusers_model` after `load_state_dict` has been called with the converted weights.
    """
    # caption_decoder.pth ckpt is a torch state dict.
    # NOTE: torch.load unpickles the file, so only convert checkpoints from trusted sources.
    checkpoint_state_dict = torch.load(ckpt, map_location="cpu")

    # Remove the leading "module." prefix, if present. Only the prefix is stripped. The previous
    # `str.replace` also deleted any "module." occurring later in a key (e.g. inside "submodule."),
    # which corrupted those keys.
    caption_decoder_key = "module."
    decoder_state_dict = {}
    for key, val in checkpoint_state_dict.items():
        if key.startswith(caption_decoder_key):
            key = key[len(caption_decoder_key) :]
        decoder_state_dict[key] = val

    new_state_dict = {}

    # Encoder and Decoder (prefix projections) keep their names.
    for name in ("encode_prefix.weight", "encode_prefix.bias", "decode_prefix.weight", "decode_prefix.bias"):
        new_state_dict[name] = decoder_state_dict[name]

    # Internal GPT2LMHeadModel transformer model: "gpt<rest>" -> "transformer<rest>".
    gpt_prefix = "gpt"
    for key, val in decoder_state_dict.items():
        if key.startswith(gpt_prefix):
            new_state_dict["transformer" + key[len(gpt_prefix) :]] = val

    missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict)
    for missing_key in missing_keys:
        print(f"Missing key: {missing_key}")
    for unexpected_key in unexpected_keys:
        print(f"Unexpected key: {unexpected_key}")

    return diffusers_model
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # All three checkpoints are required because the pipeline assembled at the end needs the converted VAE, U-ViT
    # and caption decoder. Otherwise a missing path would only surface as a NameError at pipeline construction,
    # after the other (slow) conversions had already run.
    parser.add_argument(
        "--caption_decoder_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to caption decoder checkpoint to convert.",
    )
    parser.add_argument(
        "--uvit_checkpoint_path", default=None, type=str, required=True, help="Path to U-ViT checkpoint to convert."
    )
    parser.add_argument(
        "--vae_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to VAE checkpoint to convert.",
    )
    parser.add_argument(
        "--pipeline_output_path",
        default=None,
        type=str,
        required=True,
        help="Path to save the output pipeline to.",
    )
    parser.add_argument(
        "--config_type",
        default="test",
        type=str,
        choices=["test", "big"],
        help=(
            "Config type to use. Should be 'test' to create small models for testing or 'big' to convert a full"
            " checkpoint."
        ),
    )
    parser.add_argument(
        "--version",
        default=0,
        type=int,
        choices=[0, 1],
        help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
    )
    parser.add_argument(
        "--safe_serialization",
        action="store_true",
        help="Whether to use safetensors/safe serialization when saving the pipeline.",
    )

    args = parser.parse_args()

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(args.config_type)
    vae = AutoencoderKL(**vae_config)
    vae = convert_vae_to_diffusers(args.vae_checkpoint_path, vae)

    # Convert the U-ViT ("unet") model.
    unet_config = create_unidiffuser_unet_config(args.config_type, args.version)
    unet = UniDiffuserModel(**unet_config)
    unet = convert_uvit_to_diffusers(args.uvit_checkpoint_path, unet)

    # Convert the caption decoder ("text_decoder") model.
    text_decoder_config = create_text_decoder_config(args.config_type)
    text_decoder = UniDiffuserTextDecoder(**text_decoder_config)
    text_decoder = convert_caption_decoder_to_diffusers(args.caption_decoder_checkpoint_path, text_decoder)

    # Scheduler is the same for both the test and big models.
    scheduler_config = SCHEDULER_CONFIG
    scheduler = DPMSolverMultistepScheduler(
        beta_start=scheduler_config.beta_start,
        beta_end=scheduler_config.beta_end,
        beta_schedule=scheduler_config.beta_schedule,
        solver_order=scheduler_config.solver_order,
    )

    if args.config_type == "test":
        # Make a small random CLIPTextModel
        torch.manual_seed(0)
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(clip_text_encoder_config)
        clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # Make a small random CLIPVisionModel and accompanying CLIPImageProcessor
        torch.manual_seed(0)
        clip_image_encoder_config = CLIPVisionConfig(
            image_size=32,
            patch_size=2,
            num_channels=3,
            hidden_size=32,
            projection_dim=32,
            num_hidden_layers=5,
            num_attention_heads=4,
            intermediate_size=37,
            dropout=0.1,
            attention_dropout=0.1,
            initializer_range=0.02,
        )
        image_encoder = CLIPVisionModelWithProjection(clip_image_encoder_config)
        image_processor = CLIPImageProcessor(crop_size=32, size=32)

        text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
    else:  # "big"; argparse `choices` restricts --config_type to "test" or "big".
        text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

        text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Register the extra "<|EOS|>" token; the text decoder configs reserve one extra vocab slot for it.
    # Note that the text_decoder should already have its token embeddings resized.
    eos = "<|EOS|>"
    special_tokens_dict = {"eos_token": eos}
    text_tokenizer.add_special_tokens(special_tokens_dict)

    pipeline = UniDiffuserPipeline(
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        clip_image_processor=image_processor,
        clip_tokenizer=clip_tokenizer,
        text_decoder=text_decoder,
        text_tokenizer=text_tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
|