Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratio) (#9708)
* first add a script for DC-AE;
* DC-AE init
* replace triton with custom implementation
* 1. rename file and remove un-used codes;
* no longer rely on omegaconf and dataclass
* replace custom activation with diffuers activation
* remove dc_ae attention in attention_processor.py
* iinherit from ModelMixin
* inherit from ConfigMixin
* dc-ae reduce to one file
* update downsample and upsample
* clean code
* support DecoderOutput
* remove get_same_padding and val2tuple
* remove autocast and some assert
* update ResBlock
* remove contents within super().__init__
* Update src/diffusers/models/autoencoders/dc_ae.py Co-authored-by: YiYi Xu <yixu310@gmail.com>
* remove opsequential
* update other blocks to support the removal of build_norm
* remove build encoder/decoder project in/out
* remove inheritance of RMSNorm2d from LayerNorm
* remove reset_parameters for RMSNorm2d Co-authored-by: YiYi Xu <yixu310@gmail.com>
* remove device and dtype in RMSNorm2d __init__ Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/autoencoders/dc_ae.py Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/autoencoders/dc_ae.py Co-authored-by: YiYi Xu <yixu310@gmail.com>
* Update src/diffusers/models/autoencoders/dc_ae.py Co-authored-by: YiYi Xu <yixu310@gmail.com>
* remove op_list & build_block
* remove build_stage_main
* change file name to autoencoder_dc
* move LiteMLA to attention.py
* align with other vae decode output;
* add DC-AE into init files;
* update
* make quality && make style;
* quick push before dgx disappears again
* update
* make style
* update
* update
* fix
* refactor
* refactor
* refactor
* update
* possibly change to nn.Linear
* refactor
* make fix-copies
* replace vae with ae
* replace get_block_from_block_type to get_block
* replace downsample_block_type from Conv to conv for consistency
* add scaling factors
* incorporate changes for all checkpoints
* make style
* move mla to attention processor file; split qkv conv to linears
* refactor
* add tests
* from original file loader
* add docs
* add standard autoencoder methods
* combine attention processor
* fix tests
* update
* minor fix
* minor fix
* minor fix & in/out shortcut rename
* minor fix
* make style
* fix paper link
* update docs
* update single file loading
* make style
* remove single file loading support; todo for DN6
* Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* add abstract

---------

Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: chenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@@ -314,6 +314,8 @@
      title: AutoencoderKLMochi
    - local: api/models/asymmetricautoencoderkl
      title: AsymmetricAutoencoderKL
    - local: api/models/autoencoder_dc
      title: AutoencoderDC
    - local: api/models/consistency_decoder_vae
      title: ConsistencyDecoderVAE
    - local: api/models/autoencoder_oobleck
docs/source/en/api/models/autoencoder_dc.md (new file, 50 lines)
@@ -0,0 +1,50 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderDC

The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, and Song Han from MIT HAN Lab.

The abstract from the paper is:

*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*

The following DCAE models are released and supported in Diffusers. In the checkpoint names, `fN` is the spatial compression ratio (e.g., `f32` compresses a 512x512 image into a 16x16 latent) and `cM` is the number of latent channels.

| Diffusers format | Original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0) |
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0) |
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0) |
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0) |
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0) |
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) |
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) |

Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
```python
import torch

from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```
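
Encoding and decoding follow the standard autoencoder interface documented below. A minimal round-trip sketch (the random 512x512 input is an illustrative stand-in for a real preprocessed image; the latent shape assumes the f32c32 checkpoint loaded above):

```python
import torch

# Continues from the `ae` loaded above (f32c32: 32x spatial compression, 32 latent channels).
image = torch.randn(1, 3, 512, 512, device="cuda")  # stand-in for an image normalized to [-1, 1]

with torch.no_grad():
    latent = ae.encode(image).latent                     # shape: (1, 32, 16, 16)
    latent = latent * ae.config.scaling_factor           # scale latents before passing them to a diffusion model
    reconstruction = ae.decode(latent / ae.config.scaling_factor).sample  # shape: (1, 3, 512, 512)
```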

## AutoencoderDC

[[autodoc]] AutoencoderDC
  - encode
  - decode
  - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
scripts/convert_dcae_to_diffusers.py (new file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
|
||||
|
||||
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
|
||||
qkv = state_dict.pop(key)
|
||||
q, k, v = torch.chunk(qkv, 3, dim=0)
|
||||
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
|
||||
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
|
||||
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
|
||||
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
|
||||
|
||||
|
||||
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
|
||||
parent_module, _, _ = key.rpartition(".proj.conv.weight")
|
||||
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
|
||||
|
||||
|
||||
AE_KEYS_RENAME_DICT = {
|
||||
# common
|
||||
"main.": "",
|
||||
"op_list.": "",
|
||||
"context_module": "attn",
|
||||
"local_module": "conv_out",
|
||||
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
|
||||
# If there were more scales, there would be more layers, so a loop would be better to handle this
|
||||
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
|
||||
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
|
||||
"depth_conv.conv": "conv_depth",
|
||||
"inverted_conv.conv": "conv_inverted",
|
||||
"point_conv.conv": "conv_point",
|
||||
"point_conv.norm": "norm",
|
||||
"conv.conv.": "conv.",
|
||||
"conv1.conv": "conv1",
|
||||
"conv2.conv": "conv2",
|
||||
"conv2.norm": "norm",
|
||||
"proj.norm": "norm_out",
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in",
|
||||
"encoder.project_out.0.conv": "encoder.conv_out",
|
||||
"encoder.stages": "encoder.down_blocks",
|
||||
# decoder
|
||||
"decoder.project_in.conv": "decoder.conv_in",
|
||||
"decoder.project_out.0": "decoder.norm_out",
|
||||
"decoder.project_out.2.conv": "decoder.conv_out",
|
||||
"decoder.stages": "decoder.up_blocks",
|
||||
}
|
||||
|
||||
AE_F32C32_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_F64C128_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_F128C512_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_SPECIAL_KEYS_REMAP = {
|
||||
"qkv.conv.weight": remap_qkv_,
|
||||
"proj.conv.weight": remap_proj_conv_,
|
||||
}
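# Illustrative example (the original key below is an assumption about the source checkpoint layout):
# the substring renames in AE_KEYS_RENAME_DICT map a key such as
#   "encoder.stages.3.op_list.0.context_module.qkv.conv.weight"
# to
#   "encoder.down_blocks.3.0.attn.qkv.conv.weight",
# after which the `remap_qkv_` handler registered in AE_SPECIAL_KEYS_REMAP splits it into to_q/to_k/to_v weights.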
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_ae(config_name: str, dtype: torch.dtype):
|
||||
config = get_ae_config(config_name)
|
||||
hub_id = f"mit-han-lab/{config_name}"
|
||||
ckpt_path = hf_hub_download(hub_id, "model.safetensors")
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
|
||||
ae = AutoencoderDC(**config).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
ae.load_state_dict(original_state_dict, strict=True)
|
||||
return ae
|
||||
|
||||
|
||||
def get_ae_config(name: str):
|
||||
if name in ["dc-ae-f32c32-sana-1.0"]:
|
||||
config = {
|
||||
"latent_channels": 32,
|
||||
"encoder_block_types": (
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"decoder_block_types": (
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
|
||||
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
|
||||
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
|
||||
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
|
||||
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
|
||||
"downsample_block_type": "conv",
|
||||
"upsample_block_type": "interpolate",
|
||||
"decoder_norm_types": "rms_norm",
|
||||
"decoder_act_fns": "silu",
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
|
||||
config = {
|
||||
"latent_channels": 32,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
|
||||
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f32c32-in-1.0":
|
||||
config["scaling_factor"] = 0.3189
|
||||
elif name == "dc-ae-f32c32-mix-1.0":
|
||||
config["scaling_factor"] = 0.4552
|
||||
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
|
||||
config = {
|
||||
"latent_channels": 128,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
|
||||
"decoder_norm_types": [
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f64c128-in-1.0":
|
||||
config["scaling_factor"] = 0.2889
|
||||
elif name == "dc-ae-f64c128-mix-1.0":
|
||||
config["scaling_factor"] = 0.4538
|
||||
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
|
||||
config = {
|
||||
"latent_channels": 512,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
|
||||
"decoder_norm_types": [
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f128c512-in-1.0":
|
||||
config["scaling_factor"] = 0.4883
|
||||
elif name == "dc-ae-f128c512-mix-1.0":
|
||||
config["scaling_factor"] = 0.3620
|
||||
else:
|
||||
raise ValueError("Invalid config name provided.")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default="dc-ae-f32c32-sana-1.0",
|
||||
choices=[
|
||||
"dc-ae-f32c32-sana-1.0",
|
||||
"dc-ae-f32c32-in-1.0",
|
||||
"dc-ae-f32c32-mix-1.0",
|
||||
"dc-ae-f64c128-in-1.0",
|
||||
"dc-ae-f64c128-mix-1.0",
|
||||
"dc-ae-f128c512-in-1.0",
|
||||
"dc-ae-f128c512-mix-1.0",
|
||||
],
|
||||
help="The DCAE checkpoint to convert",
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
ae = convert_ae(args.config_name, dtype)
|
||||
ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
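# Example invocation (output path is illustrative):
#   python scripts/convert_dcae_to_diffusers.py --config_name dc-ae-f32c32-sana-1.0 \
#       --output_path ./dc-ae-f32c32-sana-1.0-diffusers --dtype fp32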
|
||||
@@ -80,6 +80,7 @@ else:
|
||||
"AllegroTransformer3DModel",
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AuraFlowTransformer2DModel",
|
||||
"AutoencoderDC",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
@@ -572,6 +573,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
AuraFlowTransformer2DModel,
|
||||
AutoencoderDC,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
|
||||
@@ -27,6 +27,7 @@ _import_structure = {}
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
@@ -88,6 +89,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderDC,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
|
||||
@@ -752,6 +752,98 @@ class Attention(nn.Module):
|
||||
self.fused_projections = fuse
|
||||
|
||||
|
||||
class SanaMultiscaleAttentionProjection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
kernel_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
channels = 3 * in_channels
|
||||
self.proj_in = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=channels,
|
||||
bias=False,
|
||||
)
|
||||
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaMultiscaleLinearAttention(nn.Module):
|
||||
r"""Lightweight multi-scale linear attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: Tuple[int, ...] = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import
|
||||
from .normalization import get_normalization
|
||||
|
||||
self.eps = eps
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
num_attention_heads = (
|
||||
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
|
||||
)
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
|
||||
|
||||
self.to_qkv_multiscale = nn.ModuleList()
|
||||
for kernel_size in kernel_sizes:
|
||||
self.to_qkv_multiscale.append(
|
||||
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
|
||||
)
|
||||
|
||||
self.nonlinearity = nn.ReLU()
|
||||
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
|
||||
self.norm_out = get_normalization(norm_type, num_features=out_channels)
|
||||
|
||||
self.processor = SanaMultiscaleAttnProcessor2_0()
|
||||
|
||||
def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
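# Linear attention trick: compute (V @ K^T) @ Q instead of V @ (K^T @ Q) so the cost scales
# linearly with the number of tokens; the row of ones padded onto V below accumulates the
# normalization denominator in the same matmuls.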
|
||||
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
|
||||
scores = torch.matmul(value, key.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
|
||||
return hidden_states
|
||||
|
||||
def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.processor(self, hidden_states)
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
@@ -5007,6 +5099,66 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaMultiscaleAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing multiscale linear attention (with a quadratic attention fallback for small spatial sizes).
|
||||
"""
|
||||
|
||||
def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
height, width = hidden_states.shape[-2:]
use_linear_attention = height * width > attn.attention_head_dim
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
batch_size, _, height, width = list(hidden_states.size())
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
hidden_states = hidden_states.movedim(1, -1)
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
hidden_states = torch.cat([query, key, value], dim=3)
|
||||
hidden_states = hidden_states.movedim(-1, 1)
|
||||
|
||||
multi_scale_qkv = [hidden_states]
|
||||
for block in attn.to_qkv_multiscale:
|
||||
multi_scale_qkv.append(block(hidden_states))
|
||||
|
||||
hidden_states = torch.cat(multi_scale_qkv, dim=1)
|
||||
|
||||
if use_linear_attention:
|
||||
# for linear attention upcast hidden_states to float32
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
|
||||
|
||||
query, key, value = hidden_states.chunk(3, dim=2)
|
||||
query = attn.nonlinearity(query)
|
||||
key = attn.nonlinearity(key)
|
||||
|
||||
if use_linear_attention:
|
||||
hidden_states = attn.apply_linear_attention(query, key, value)
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
else:
|
||||
hidden_states = attn.apply_quadratic_attention(query, key, value)
|
||||
|
||||
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
|
||||
hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if attn.norm_type == "rms_norm":
|
||||
hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = attn.norm_out(hidden_states)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LoRAAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
|
||||
src/diffusers/models/autoencoders/autoencoder_dc.py (new file, 648 lines)
@@ -0,0 +1,648 @@
|
||||
# Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SanaMultiscaleLinearAttention
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm, get_normalization
|
||||
from .vae import DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = 4 * in_channels
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
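# GLU-style gating: split the depthwise conv output into a value half and a gate half,
# then modulate the value by SiLU(gate).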
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: Tuple[int, ...] = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = SanaMultiscaleLinearAttention(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
mult=mult,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
kernel_sizes=qkv_multiscales,
|
||||
residual_connection=True,
|
||||
)
|
||||
|
||||
self.conv_out = GLUMBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.attn(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_block(
|
||||
block_type: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||
|
||||
return block
|
||||
|
||||
|
||||
class DCDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.factor = 2
|
||||
self.stride = 1 if downsample else 2
|
||||
self.group_size = in_channels * self.factor**2 // out_channels
|
||||
self.shortcut = shortcut
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if downsample:
|
||||
assert out_channels % out_ratio == 0
|
||||
out_channels = out_channels // out_ratio
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(hidden_states)
|
||||
if self.downsample:
|
||||
x = F.pixel_unshuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
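# Residual Autoencoding shortcut: pixel-unshuffle the input (space-to-channel) and average
# groups of channels so the skip connection matches the shape of the conv output.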
|
||||
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||
y = y.unflatten(1, (-1, self.group_size))
|
||||
y = y.mean(dim=2)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DCUpBlock2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
interpolate: bool = False,
|
||||
shortcut: bool = True,
|
||||
interpolation_mode: str = "nearest",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.interpolate = interpolate
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.shortcut = shortcut
|
||||
self.factor = 2
|
||||
self.repeats = out_channels * self.factor**2 // in_channels
|
||||
|
||||
out_ratio = self.factor**2
|
||||
|
||||
if not interpolate:
|
||||
out_channels = out_channels * out_ratio
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.interpolate:
|
||||
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv(hidden_states)
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
else:
|
||||
self.conv_in = DCDownBlock2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
down_blocks = []
|
||||
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||
down_block_list = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
downsample_block = DCDownBlock2d(
|
||||
in_channels=out_channel,
|
||||
out_channels=block_out_channels[i + 1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=True,
|
||||
)
|
||||
down_block_list.append(downsample_block)
|
||||
|
||||
down_blocks.append(nn.Sequential(*down_block_list))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||
|
||||
self.out_shortcut = out_shortcut
|
||||
if out_shortcut:
|
||||
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
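# Output shortcut: average groups of encoder features down to `latent_channels` and add the
# result to the conv_out projection (the residual path of the latent projection).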
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
if self.out_shortcut:
|
||||
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||
x = x.mean(dim=2)
|
||||
hidden_states = self.conv_out(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
if isinstance(norm_type, str):
|
||||
norm_type = (norm_type,) * num_blocks
|
||||
if isinstance(act_fn, str):
|
||||
act_fn = (act_fn,) * num_blocks
|
||||
|
||||
self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||
|
||||
self.in_shortcut = in_shortcut
|
||||
if in_shortcut:
|
||||
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||
|
||||
up_blocks = []
|
||||
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||
up_block_list = []
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
upsample_block = DCUpBlock2d(
|
||||
block_out_channels[i + 1],
|
||||
out_channel,
|
||||
interpolate=upsample_block_type == "interpolate",
|
||||
shortcut=True,
|
||||
)
|
||||
up_block_list.append(upsample_block)
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
|
||||
else:
|
||||
self.conv_out = DCUpBlock2d(
|
||||
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for up_block in reversed(self.up_blocks):
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
|
||||
[SANA](https://arxiv.org/abs/2410.10629).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `3`):
|
||||
The number of input channels in samples.
|
||||
latent_channels (`int`, defaults to `32`):
|
||||
The number of channels in the latent space representation.
|
||||
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the encoder.
|
||||
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the decoder.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the encoder.
|
||||
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the decoder.
|
||||
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
The number of layers per block in the encoder.
|
||||
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
The number of layers per block in the decoder.
|
||||
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
|
||||
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
|
||||
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
|
||||
The type of block to use for upsampling in the decoder.
|
||||
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
|
||||
The type of block to use for downsampling in the encoder.
|
||||
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
|
||||
The normalization type(s) to use in the decoder.
|
||||
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
|
||||
The activation function(s) to use in the decoder.
|
||||
scaling_factor (`float`, defaults to `1.0`):
|
||||
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
|
||||
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
|
||||
z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back
|
||||
to the original scale with the formula: `z = 1 / scaling_factor * z`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
scaling_factor: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=encoder_block_types,
|
||||
block_out_channels=encoder_block_out_channels,
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=decoder_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
qkv_multiscales=decoder_qkv_multiscales,
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
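# e.g. 6 encoder blocks -> 2 ** 5 = 32x spatial compression, the "f32" in dc-ae-f32c32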
|
||||
self.temporal_compression_ratio = 1
|
||||
|
||||
# When decoding a batch of latents, memory can be saved by slicing across the batch dimension
# and decoding a single latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large latents, the memory requirement is very high. By splitting the latent
# into smaller spatial tiles, performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 512
|
||||
self.tile_sample_min_width = 512
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 448
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled AE decoding. When this option is enabled, the AE will split the input tensor into tiles to compute
|
||||
decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_sample_stride_height (`int`, *optional*):
|
||||
The stride between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension.
|
||||
tile_sample_stride_width (`int`, *optional*):
|
||||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||
artifacts produced across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
|
||||
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x, return_dict=False)[0]
|
||||
|
||||
encoded = self.encoder(x)
|
||||
|
||||
return encoded
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
encoded = torch.cat(encoded_slices)
|
||||
else:
|
||||
encoded = self._encode(x)
|
||||
|
||||
if not return_dict:
|
||||
return (encoded,)
|
||||
return EncoderOutput(latent=encoded)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width // self.spatial_compression_ratio or height > self.tile_sample_min_height // self.spatial_compression_ratio):
|
||||
return self.tiled_decode(z, return_dict=False)[0]
|
||||
|
||||
decoded = self.decoder(z)
|
||||
|
||||
return decoded
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.size(0) > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
|
||||
|
||||
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
encoded = self.encode(sample, return_dict=False)[0]
|
||||
decoded = self.decode(encoded, return_dict=False)[0]
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
@@ -30,6 +30,19 @@ from ..unets.unet_2d_blocks import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncoderOutput(BaseOutput):
|
||||
r"""
|
||||
Output of encoding method.
|
||||
|
||||
Args:
|
||||
latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
|
||||
The encoded latent.
|
||||
"""
|
||||
|
||||
latent: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecoderOutput(BaseOutput):
|
||||
r"""
|
||||
|
||||
@@ -512,20 +512,24 @@ else:
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.weight = None
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
@@ -537,6 +541,8 @@ class RMSNorm(nn.Module):
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
@@ -566,3 +572,21 @@ class LpNorm(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
|
||||
|
||||
|
||||
def get_normalization(
|
||||
norm_type: str = "batch_norm",
|
||||
num_features: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
elementwise_affine: bool = True,
|
||||
bias: bool = True,
|
||||
) -> nn.Module:
|
||||
if norm_type == "rms_norm":
|
||||
norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
|
||||
elif norm_type == "layer_norm":
|
||||
norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
|
||||
elif norm_type == "batch_norm":
|
||||
norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"{norm_type=} is not supported.")
|
||||
return norm
|
||||
|
||||
@@ -47,6 +47,21 @@ class AuraFlowTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderDC(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
tests/models/autoencoders/test_models_autoencoder_dc.py (new file, 87 lines)
@@ -0,0 +1,87 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_dc_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"attention_head_dim": 2,
|
||||
"encoder_block_types": (
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"decoder_block_types": (
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"encoder_block_out_channels": (8, 8),
|
||||
"decoder_block_out_channels": (8, 8),
|
||||
"encoder_qkv_multiscales": ((), (5,)),
|
||||
"decoder_qkv_multiscales": ((), (5,)),
|
||||
"encoder_layers_per_block": (1, 1),
|
||||
"decoder_layers_per_block": [1, 1],
|
||||
"downsample_block_type": "conv",
|
||||
"upsample_block_type": "interpolate",
|
||||
"decoder_norm_types": "rms_norm",
|
||||
"decoder_act_fns": "silu",
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_dc_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||