Files
diffusers/scripts/convert_dcae_to_diffusers.py
Junsong Chen cd892041e2 [DC-AE] Add the official Deep Compression Autoencoder code(32x,64x,128x compression ratio); (#9708)
* first add a script for DC-AE;

* DC-AE init

* replace triton with custom implementation

* 1. rename file and remove unused code;

* no longer rely on omegaconf and dataclass

* replace custom activation with diffusers activation

* remove dc_ae attention in attention_processor.py

* inherit from ModelMixin

* inherit from ConfigMixin

* dc-ae reduce to one file

* update downsample and upsample

* clean code

* support DecoderOutput

* remove get_same_padding and val2tuple

* remove autocast and some assert

* update ResBlock

* remove contents within super().__init__

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove opsequential

* update other blocks to support the removal of build_norm

* remove build encoder/decoder project in/out

* remove inheritance of RMSNorm2d from LayerNorm

* remove reset_parameters for RMSNorm2d

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove device and dtype in RMSNorm2d __init__

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove op_list & build_block

* remove build_stage_main

* change file name to autoencoder_dc

* move LiteMLA to attention.py

* align with other vae decode output;

* add DC-AE into init files;

* update

* make quality && make style;

* quick push before dgx disappears again

* update

* make style

* update

* update

* fix

* refactor

* refactor

* refactor

* update

* possibly change to nn.Linear

* refactor

* make fix-copies

* replace vae with ae

* replace get_block_from_block_type to get_block

* replace downsample_block_type from Conv to conv for consistency

* add scaling factors

* incorporate changes for all checkpoints

* make style

* move mla to attention processor file; split qkv conv to linears

* refactor

* add tests

* from original file loader

* add docs

* add standard autoencoder methods

* combine attention processor

* fix tests

* update

* minor fix

* minor fix

* minor fix & in/out shortcut rename

* minor fix

* make style

* fix paper link

* update docs

* update single file loading

* make style

* remove single file loading support; todo for DN6

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add abstract

---------

Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: chenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-12-07 01:01:51 +05:30

324 lines
11 KiB
Python

import argparse
from typing import Any, Dict
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers import AutoencoderDC
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    """Replace a fused QKV conv weight with separate ``to_q``/``to_k``/``to_v`` weights.

    The entry stored under ``key`` is removed from ``state_dict``, cut into
    three chunks along dim 0, and every chunk is squeezed and stored under the
    module path that precedes ``.qkv.conv.weight`` in ``key``.

    Args:
        key: Checkpoint key of the fused weight.
        state_dict: Mapping that is modified in place.
    """
    fused_weight = state_dict.pop(key)
    # Unpacking into exactly three names keeps the failure mode of a weight
    # that cannot be split three ways.
    query_part, key_part, value_part = torch.chunk(fused_weight, 3, dim=0)
    prefix = key.rpartition(".qkv.conv.weight")[0]
    for target, part in (("to_q", query_part), ("to_k", key_part), ("to_v", value_part)):
        state_dict[f"{prefix}.{target}.weight"] = part.squeeze()
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    """Move a ``proj.conv`` weight to the sibling ``to_out`` entry, squeezed.

    Args:
        key: Checkpoint key of the projection conv weight.
        state_dict: Mapping that is modified in place.
    """
    prefix = key.rpartition(".proj.conv.weight")[0]
    conv_weight = state_dict.pop(key)
    state_dict[f"{prefix}.to_out.weight"] = conv_weight.squeeze()
# Substring renames applied by `convert_ae` to every key of the original
# checkpoint. The entries are applied in insertion order, each one on the
# result of the previous replacement, so the order of this table is part of
# its behavior.
AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}
# Overrides that `get_ae_config` merges into AE_KEYS_RENAME_DICT for the
# f32c32 "in"/"mix" checkpoints. They replace the values of keys that already
# exist in the table above.
AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}
# Same overrides, merged for the f64c128 "in"/"mix" checkpoints.
AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}
# Same overrides, merged for the f128c512 "in"/"mix" checkpoints.
# NOTE(review): the three override tables currently hold identical entries.
AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}
# Key substring -> in-place handler. After the plain renames, `convert_ae`
# calls the handler as `handler(key, state_dict)` for every key that contains
# the substring.
AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap a loaded checkpoint down to the actual parameter mapping.

    Checkpoints may nest the weights under ``"model"``, ``"module"`` and/or
    ``"state_dict"``. The wrappers are peeled off in that order, and each
    check is made against the mapping reached so far, so nested wrappers such
    as ``{"model": {"state_dict": ...}}`` are fully unwrapped.

    Args:
        saved_dict: Object returned by the checkpoint loader.

    Returns:
        The innermost mapping; ``saved_dict`` itself when no wrapper key is
        present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Test the current level, not the outermost dict: a wrapper key that
        # only exists at the top level must not be looked up one level down.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename ``old_key`` to ``new_key`` in ``state_dict``, in place.

    The value is popped and re-inserted, so the renamed entry moves to the
    end of the mapping. Nothing is returned.

    Raises:
        KeyError: If ``old_key`` is not present.
    """
    state_dict[new_key] = state_dict.pop(old_key)
def convert_ae(config_name: str, dtype: torch.dtype):
    """Download an original DC-AE checkpoint and load it into an ``AutoencoderDC``.

    Args:
        config_name: Checkpoint name; also used as the repository name under
            ``mit-han-lab/`` on the Hub.
        dtype: Dtype the freshly built model is cast to before loading.

    Returns:
        The ``AutoencoderDC`` instance with the converted weights loaded
        (strict loading).
    """
    # Resolve the config first: for some names this call also extends
    # AE_KEYS_RENAME_DICT, which the rename pass below reads.
    config = get_ae_config(config_name)
    checkpoint_file = hf_hub_download(f"mit-han-lab/{config_name}", "model.safetensors")
    weights = get_state_dict(load_file(checkpoint_file))
    ae = AutoencoderDC(**config).to(dtype=dtype)

    # Pass 1: apply every substring rename, in table order, to each key.
    for old_key in list(weights.keys()):
        renamed = old_key
        for pattern, replacement in AE_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(pattern, replacement)
        update_state_dict_(weights, old_key, renamed)

    # Pass 2: let the special handlers rewrite matching entries in place.
    for current_key in list(weights.keys()):
        for marker, remap_fn in AE_SPECIAL_KEYS_REMAP.items():
            if marker in current_key:
                remap_fn(current_key, weights)

    ae.load_state_dict(weights, strict=True)
    return ae
def _dc_ae_in_mix_config(latent_channels: int, num_vit_stages: int) -> Dict[str, Any]:
    """Build the config shared by the ``in``/``mix`` checkpoint families.

    Every family starts with three ``ResBlock`` stages followed by
    ``num_vit_stages`` ``EfficientViTBlock`` stages; only the stage count and
    the latent channel count differ between families.
    """
    num_res_stages = 3
    num_stages = num_res_stages + num_vit_stages
    stage_channels = [128, 256, 512, 512, 1024, 1024, 2048, 2048][:num_stages]
    block_types = ["ResBlock"] * num_res_stages + ["EfficientViTBlock"] * num_vit_stages
    return {
        "latent_channels": latent_channels,
        # Separate list objects for encoder and decoder, as in a literal config.
        "encoder_block_types": list(block_types),
        "decoder_block_types": list(block_types),
        "encoder_block_out_channels": list(stage_channels),
        "decoder_block_out_channels": list(stage_channels),
        "encoder_layers_per_block": [0, 4, 8] + [2] * num_vit_stages,
        "decoder_layers_per_block": [0, 5, 10] + [2] * num_vit_stages,
        "encoder_qkv_multiscales": ((),) * num_stages,
        "decoder_qkv_multiscales": ((),) * num_stages,
        "decoder_norm_types": ["batch_norm"] * num_res_stages + ["rms_norm"] * num_vit_stages,
        "decoder_act_fns": ["relu"] * num_res_stages + ["silu"] * num_vit_stages,
    }


def get_ae_config(name: str):
    """Return the ``AutoencoderDC`` constructor kwargs for a checkpoint name.

    Side effect: for the ``in``/``mix`` checkpoints the family-specific
    overrides (``AE_F32C32_KEYS``, ``AE_F64C128_KEYS`` or ``AE_F128C512_KEYS``)
    are merged into the module-level ``AE_KEYS_RENAME_DICT``.

    Args:
        name: Checkpoint name, e.g. ``"dc-ae-f32c32-sana-1.0"``.

    Returns:
        A dict of keyword arguments including ``scaling_factor``.

    Raises:
        ValueError: If ``name`` is not a known checkpoint.
    """
    if name == "dc-ae-f32c32-sana-1.0":
        return {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }

    scaling_factors = {
        "dc-ae-f32c32-in-1.0": 0.3189,
        "dc-ae-f32c32-mix-1.0": 0.4552,
        "dc-ae-f64c128-in-1.0": 0.2889,
        "dc-ae-f64c128-mix-1.0": 0.4538,
        "dc-ae-f128c512-in-1.0": 0.4883,
        "dc-ae-f128c512-mix-1.0": 0.3620,
    }

    if name in ("dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = _dc_ae_in_mix_config(latent_channels=32, num_vit_stages=3)
    elif name in ("dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = _dc_ae_in_mix_config(latent_channels=128, num_vit_stages=4)
    elif name in ("dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = _dc_ae_in_mix_config(latent_channels=512, num_vit_stages=5)
    else:
        # Name the offending value so a typo is visible in the traceback.
        raise ValueError(f"Invalid config name provided: {name!r}.")

    config["scaling_factor"] = scaling_factors[name]
    return config
def get_args():
    """Parse the command-line options of the conversion script.

    Returns:
        The ``argparse.Namespace`` with ``config_name``, ``output_path`` and
        ``dtype`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        # Reject unsupported names at parse time with a usage error; without
        # this an unknown value only fails later as a bare KeyError when it
        # is looked up in DTYPE_MAPPING.
        choices=["fp32", "fp16", "bf16"],
        help="Torch dtype to save the model in.",
    )
    return parser.parse_args()
# Value of the --dtype option -> torch dtype handed to `convert_ae`.
DTYPE_MAPPING = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)
# Value of the --dtype option -> `variant` argument used when saving the model.
VARIANT_MAPPING = dict(fp32=None, fp16="fp16", bf16="bf16")
if __name__ == "__main__":
    # Script entry point: parse the CLI, convert the checkpoint, save it.
    cli_args = get_args()
    torch_dtype = DTYPE_MAPPING[cli_args.dtype]
    weights_variant = VARIANT_MAPPING[cli_args.dtype]
    converted_model = convert_ae(cli_args.config_name, torch_dtype)
    converted_model.save_pretrained(
        cli_args.output_path,
        safe_serialization=True,
        max_shard_size="5GB",
        variant=weights_variant,
    )