Compare commits

...

81 Commits

Author SHA1 Message Date
yiyixuxu
8c7c483bb0 combine attention processor 2024-12-04 23:55:17 +01:00
Aryan
6f29e2aba1 add standard autoencoder methods 2024-12-04 22:24:01 +01:00
Aryan
31f9fc6c96 add docs 2024-12-04 21:10:00 +01:00
Aryan
46eb504352 Merge branch 'main' into DC-AE 2024-12-05 01:26:49 +05:30
Aryan
d6c748c7e6 from original file loader 2024-12-04 20:56:24 +01:00
Junsong Chen
632ad3baac Merge branch 'main' into DC-AE 2024-12-04 23:06:02 +08:00
Aryan
da834d5177 add tests 2024-12-04 13:18:51 +01:00
Aryan
68f817ac34 Merge branch 'main' into DC-AE 2024-12-04 12:58:47 +01:00
Aryan
39a947ce7d refactor 2024-12-04 12:37:34 +01:00
Aryan
30c3238339 move mla to attention processor file; split qkv conv to linears 2024-12-04 11:30:50 +01:00
Aryan
4a224ce906 Merge branch 'main' into DC-AE 2024-12-04 12:32:04 +05:30
Aryan
eb64d52c42 make style 2024-12-04 08:01:52 +01:00
Aryan
0bda5c51bc incorporate changes for all checkpoints 2024-12-04 07:26:30 +01:00
Junyu Chen
64de66a477 add scaling factors 2024-12-02 21:31:56 -08:00
Junyu Chen
074817cfdc replace downsample_block_type from Conv to conv for consistency 2024-12-02 19:15:22 -08:00
Junyu Chen
9ef7b59725 replace get_block_from_block_type to get_block 2024-12-02 17:53:34 -08:00
Junyu Chen
ca3ac4dede replace vae with ae 2024-12-02 17:52:28 -08:00
Aryan
65edfa5e4c resolve conflicts & merge 2024-12-01 12:42:05 +01:00
Aryan
3d5faaff2d make fix-copies 2024-11-30 22:27:41 +01:00
Aryan
babc9f5990 Merge branch 'main' into aryan-dcae 2024-12-01 02:55:47 +05:30
Aryan
54e933b39f refactor 2024-11-30 22:15:00 +01:00
Aryan
0bdb7eff9a possibly change to nn.Linear 2024-11-30 21:54:22 +01:00
Aryan
c4d0867c89 update 2024-11-30 05:53:30 +01:00
Aryan
77571a8fb9 refactor 2024-11-29 16:11:31 +01:00
Aryan
637924124b refactor 2024-11-29 16:04:15 +01:00
Aryan
44034a6f24 refactor 2024-11-29 15:50:26 +01:00
Aryan
f5876c5bc4 fix 2024-11-28 22:13:21 +01:00
Aryan
a2ec5f81ac update 2024-11-28 22:01:46 +01:00
Aryan
bf6c211e65 update 2024-11-28 21:38:09 +01:00
Aryan
7b9d7e5513 make style 2024-11-28 20:24:36 +01:00
Aryan
1f8a3b3b46 update 2024-11-28 20:24:24 +01:00
Aryan
c1c02a20b7 quick push before dgx disappears again 2024-11-28 20:01:24 +01:00
lawrence-cj
2d59056c6c make quality && make style; 2024-11-29 02:23:52 +08:00
Aryan
5ed50e9825 update 2024-11-28 16:32:18 +01:00
lawrence-cj
20da201225 add DC-AE into init files; 2024-11-28 22:36:53 +08:00
lawrence-cj
be9826c5be align with other vae decode output; 2024-11-28 22:35:01 +08:00
Junyu Chen
d3d9c84068 move LiteMLA to attention.py 2024-11-28 01:25:32 -08:00
Junyu Chen
e007057de3 Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE 2024-11-28 01:01:11 -08:00
Junyu Chen
4d3c026622 change file name to autoencoder_dc 2024-11-28 01:01:01 -08:00
Junsong Chen
44957833bc Merge branch 'main' into DC-AE 2024-11-27 01:07:55 +08:00
Junyu Chen
2f6bbad8f0 remove build_stage_main 2024-11-26 04:36:57 -08:00
Junyu Chen
4f5cbb4380 remove op_list & build_block 2024-11-26 04:24:24 -08:00
Junsong Chen
22ea5fd7f7 Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-21 13:58:24 +08:00
Junsong Chen
c82f828b06 Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-21 13:58:16 +08:00
Junsong Chen
b4f75f22e4 Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-21 13:58:07 +08:00
chenjy2003
2e04a9967f remove device and dtype in RMSNorm2d __init__
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-20 19:17:43 +08:00
chenjy2003
b42bb54854 remove reset_parameters for RMSNorm2d
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-20 19:17:15 +08:00
Junyu Chen
cab56b1657 remove inheritance of RMSNorm2d from LayerNorm 2024-11-16 06:42:59 -08:00
Junyu Chen
30d6308d1e Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE 2024-11-16 05:48:35 -08:00
Junyu Chen
7ce9ff2fe0 remove build encoder/decoder project in/out 2024-11-16 05:48:28 -08:00
Sayak Paul
59b6e25703 Merge branch 'main' into DC-AE 2024-11-16 13:54:11 +01:00
Junyu Chen
96e844b786 update other blocks to support the removal of build_norm 2024-11-16 01:52:14 -08:00
Junyu Chen
25ae389066 Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE 2024-11-16 01:41:04 -08:00
Junyu Chen
883bcf4f81 remove opsequential 2024-11-16 01:40:48 -08:00
Junsong Chen
1752afdda9 Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-11-16 17:32:06 +08:00
Junyu Chen
80dce027cb remove contents within super().__init__ 2024-11-13 21:29:45 -08:00
Junyu Chen
ea604a4dc2 update ResBlock 2024-11-13 21:23:14 -08:00
Junyu Chen
59de0a3479 remove autocast and some assert 2024-11-13 20:18:37 -08:00
Junyu Chen
c6eb23361c remove get_same_padding and val2tuple 2024-11-13 20:03:45 -08:00
Junsong Chen
0e818df823 Merge branch 'main' into DC-AE 2024-11-13 18:02:34 +08:00
Junsong Chen
3481e23c2d Merge branch 'main' into DC-AE 2024-11-09 15:45:10 +08:00
Junyu Chen
19986a524c support DecoderOutput 2024-11-08 20:05:12 -08:00
Junyu Chen
dd7718a29a clean code 2024-11-08 19:59:09 -08:00
Junyu Chen
bf40fe85ae merge 2024-11-08 19:48:50 -08:00
Junyu Chen
1448681057 update downsample and upsample 2024-11-08 19:47:17 -08:00
junsongc
3c3cc51772 Merge remote-tracking branch 'refs/remotes/origin/main' into DC-AE
# Conflicts:
#	src/diffusers/models/normalization.py
2024-11-06 16:57:12 +08:00
Junsong Chen
6d96b954e6 Merge branch 'huggingface:main' into DC-AE 2024-11-04 23:19:13 +08:00
Junyu Chen
b7f68f9a2d Merge remote-tracking branch 'upstream/main' into DC-AE 2024-10-31 00:18:02 -07:00
Junyu Chen
8f9b4e43bc dc-ae reduce to one file 2024-10-31 00:17:58 -07:00
Junyu Chen
72cce2b468 inherit from ConfigMixin 2024-10-24 21:33:09 -07:00
Junyu Chen
5e63a1ae6c iinherit from ModelMixin 2024-10-24 21:13:05 -07:00
Junyu Chen
fb6d92a4b9 remove dc_ae attention in attention_processor.py 2024-10-24 21:08:50 -07:00
Junyu Chen
da7caa5850 replace custom activation with diffuers activation 2024-10-24 21:07:22 -07:00
Junyu Chen
c323e76b1b Merge remote-tracking branch 'upstream/main' into DC-AE 2024-10-24 20:51:10 -07:00
Junyu Chen
6fb7fdb8ab merge 2024-10-24 20:50:11 -07:00
Junyu Chen
55b2615dd4 no longer rely on omegaconf and dataclass 2024-10-24 20:48:26 -07:00
junsongc
3a44fa4cf3 1. rename file and remove un-used codes; 2024-10-23 16:08:52 +08:00
Junyu Chen
825c975cb9 replace triton with custom implementation 2024-10-23 00:48:42 -07:00
Junyu Chen
90e893936d DC-AE init 2024-10-22 23:49:59 -07:00
Junyu Chen
d2e187a051 Merge remote-tracking branch 'upstream/main' into DC-AE 2024-10-22 23:18:33 -07:00
junsongc
6e616a9cc8 first add a script for DC-AE; 2024-10-18 17:40:51 +08:00
15 changed files with 1586 additions and 6 deletions

View File

@@ -314,6 +314,8 @@
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
title: AutoencoderDC
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck

View File

@@ -0,0 +1,56 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderDC
*The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by Junyu Chen, Han Cai, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, and Song Han from MIT HAN Lab.*
The following DCAE models are released and supported in Diffusers:
- [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0)
- [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0)
- [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0)
- [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0)
- [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0)
- [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
- [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
The models can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderDC
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```
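Once loaded, the autoencoder can encode images into latents and decode them back. The following is a minimal sketch (the random tensor is a stand-in for a real batch of preprocessed images in `[-1, 1]`; shapes assume the `f32c32` model, i.e. 32x spatial compression with 32 latent channels):
```python
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")

image = torch.randn(1, 3, 512, 512, device="cuda")  # stand-in for a preprocessed image batch
with torch.no_grad():
    latent = ae.encode(image).latent           # (1, 32, 16, 16)
    reconstruction = ae.decode(latent).sample  # (1, 3, 512, 512)
```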
## Single file loading
The `AutoencoderDC` implementation supports loading checkpoints in the original format released by MIT HAN Lab. The following example demonstrates how to load the `f128c512` checkpoint:
```python
from diffusers import AutoencoderDC
model_name = "dc-ae-f128c512-mix-1.0"
ae = AutoencoderDC.from_single_file(
f"https://huggingface.co/mit-han-lab/{model_name}/model.safetensors",
original_config=f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/config.json"
)
```
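After conversion, `ae.save_pretrained(output_path)` writes the weights in the Diffusers layout so that later loads can use `from_pretrained` directly; the conversion script added in this PR does exactly that for each released checkpoint.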
## AutoencoderDC
[[autodoc]] AutoencoderDC
- encode
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,323 @@
import argparse
from typing import Any, Dict
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers import AutoencoderDC
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
parent_module, _, _ = key.rpartition(".proj.conv.weight")
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
AE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
"encoder.stages": "encoder.down_blocks",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out",
"decoder.project_out.2.conv": "decoder.conv_out",
"decoder.stages": "decoder.up_blocks",
}
AE_F32C32_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F64C128_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F128C512_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
"proj.conv.weight": remap_proj_conv_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def convert_ae(config_name: str, dtype: torch.dtype):
config = get_ae_config(config_name)
hub_id = f"mit-han-lab/{config_name}"
ckpt_path = hf_hub_download(hub_id, "model.safetensors")
original_state_dict = get_state_dict(load_file(ckpt_path))
ae = AutoencoderDC(**config).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
ae.load_state_dict(original_state_dict, strict=True)
return ae
def get_ae_config(name: str):
if name in ["dc-ae-f32c32-sana-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
config = {
"latent_channels": 32,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
}
if name == "dc-ae-f32c32-in-1.0":
config["scaling_factor"] = 0.3189
elif name == "dc-ae-f32c32-mix-1.0":
config["scaling_factor"] = 0.4552
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
config = {
"latent_channels": 128,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
}
if name == "dc-ae-f64c128-in-1.0":
config["scaling_factor"] = 0.2889
elif name == "dc-ae-f64c128-mix-1.0":
config["scaling_factor"] = 0.4538
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
config = {
"latent_channels": 512,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
}
if name == "dc-ae-f128c512-in-1.0":
config["scaling_factor"] = 0.4883
elif name == "dc-ae-f128c512-mix-1.0":
config["scaling_factor"] = 0.3620
else:
raise ValueError("Invalid config name provided.")
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_name",
type=str,
default="dc-ae-f32c32-sana-1.0",
choices=[
"dc-ae-f32c32-sana-1.0",
"dc-ae-f32c32-in-1.0",
"dc-ae-f32c32-mix-1.0",
"dc-ae-f64c128-in-1.0",
"dc-ae-f64c128-mix-1.0",
"dc-ae-f128c512-in-1.0",
"dc-ae-f128c512-mix-1.0",
],
help="The DCAE checkpoint to convert",
)
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
args = get_args()
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
ae = convert_ae(args.config_name, dtype)
ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

View File

@@ -80,6 +80,7 @@ else:
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
@@ -572,6 +573,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,

View File

@@ -23,12 +23,14 @@ from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_autoencoder_dc_config_from_original,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -82,6 +84,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderDC": {
"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers,
"config_mapping_fn": create_autoencoder_dc_config_from_original,
},
}
@@ -228,7 +234,7 @@ class FromOriginalModelMixin:
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
f"`original_config` has been provided for {mapping_class_name} but no mapping function "
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)

View File

@@ -12,7 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""
"""
Conversion scripts for the various modeling checkpoints. These scripts convert original model implementations to
Diffusers adapted versions. This usually only involves renaming/remapping the state dict keys and changing some
modeling components partially (for example, splitting a single QKV linear to individual Q, K, V layers).
"""
import copy
import os
@@ -92,6 +97,7 @@ CHECKPOINT_KEY_NAMES = {
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -2198,3 +2204,250 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
)
return converted_state_dict
def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")
if model_name in ["dc-ae-f32c32-sana-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
elif model_name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f32c32-in-1.0":
config["scaling_factor"] = 0.3189
elif model_name == "dc-ae-f32c32-mix-1.0":
config["scaling_factor"] = 0.4552
elif model_name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
config = {
"latent_channels": 128,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f64c128-in-1.0":
config["scaling_factor"] = 0.2889
elif model_name == "dc-ae-f64c128-mix-1.0":
config["scaling_factor"] = 0.4538
elif model_name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
config = {
"latent_channels": 512,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f128c512-in-1.0":
config["scaling_factor"] = 0.4883
elif model_name == "dc-ae-f128c512-mix-1.0":
config["scaling_factor"] = 0.3620
config.update({"model_name": model_name})
return config
def convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
model_name = config.pop("model_name")
def remap_qkv_(key: str, state_dict):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
def remap_proj_conv_(key: str, state_dict):
parent_module, _, _ = key.rpartition(".proj.conv.weight")
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
AE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
"encoder.stages": "encoder.down_blocks",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out",
"decoder.project_out.2.conv": "decoder.conv_out",
"decoder.stages": "decoder.up_blocks",
}
AE_F32C32_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F64C128_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F128C512_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
"proj.conv.weight": remap_proj_conv_,
}
if "f32c32" in model_name and "sana" not in model_name:
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
elif "f64c128" in model_name:
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
elif "f128c512" in model_name:
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict

View File

@@ -27,6 +27,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
@@ -88,6 +89,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,

View File

@@ -22,7 +22,13 @@ from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
from .normalization import (
AdaLayerNorm,
AdaLayerNormContinuous,
AdaLayerNormZero,
RMSNorm,
SD35AdaLayerNormZeroX,
)
logger = logging.get_logger(__name__)

View File

@@ -752,6 +752,98 @@ class Attention(nn.Module):
self.fused_projections = fuse
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = nn.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
r"""Lightweight multi-scale linear attention"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: Optional[int] = None,
mult: float = 1.0,
attention_head_dim: int = 8,
norm_type: str = "batch_norm",
kernel_sizes: Tuple[int, ...] = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
# To prevent circular import
from .normalization import get_normalization
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, num_features=out_channels)
self.processor = SanaMultiscaleAttnProcessor2_0()
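# Linear attention computes (V Kᵀ) Q instead of V (Kᵀ Q), which is linear rather than quadratic in the token count.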
def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
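# Softmax-free quadratic attention: ReLU similarity scores are normalized by their sum over the key tokens.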
def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores)
return hidden_states
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.processor(self, hidden_states)
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
@@ -5005,6 +5097,66 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
return hidden_states
class SanaMultiscaleAttnProcessor2_0:
r"""
Processor for implementing multiscale linear attention, with a quadratic fallback for small spatial inputs.
"""
def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
height, width = hidden_states.shape[-2:]
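# linear attention costs O(N) in the number of tokens; use it once the token count exceeds the head dimension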
use_linear_attention = height * width > attn.attention_head_dim
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in attn.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = attn.nonlinearity(query)
key = attn.nonlinearity(key)
if use_linear_attention:
hidden_states = attn.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = attn.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if attn.norm_type == "rms_norm":
hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = attn.norm_out(hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class LoRAAttnProcessor:
def __init__(self):
pass

View File

@@ -1,4 +1,5 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX

View File

@@ -0,0 +1,638 @@
# Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from .vae import DecoderOutput, EncoderOutput
class GLUMBConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
hidden_channels = 4 * in_channels
self.nonlinearity = nn.SiLU()
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
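# GLU gating: half of the expanded channels act as a SiLU gate for the other half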
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
# move channels to the last dimension so RMSNorm is applied across the channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
return hidden_states + residual
class ResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm_type: str = "batch_norm",
act_fn: str = "relu6",
) -> None:
super().__init__()
self.norm_type = norm_type
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
self.norm = get_normalization(norm_type, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm_type == "rms_norm":
# move channels to the last dimension so RMSNorm is applied across the channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = self.norm(hidden_states)
return hidden_states + residual
class EfficientViTBlock(nn.Module):
def __init__(
self,
in_channels: int,
mult: float = 1.0,
attention_head_dim: int = 32,
qkv_multiscales: Tuple[int, ...] = (5,),
norm_type: str = "batch_norm",
) -> None:
super().__init__()
self.attn = SanaMultiscaleLinearAttention(
in_channels=in_channels,
out_channels=in_channels,
mult=mult,
attention_head_dim=attention_head_dim,
norm_type=norm_type,
kernel_sizes=qkv_multiscales,
residual_connection=True,
)
self.conv_out = GLUMBConv(
in_channels=in_channels,
out_channels=in_channels,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.conv_out(x)
return x
def get_block(
block_type: str,
in_channels: int,
out_channels: int,
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_multiscales: Tuple[int] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
)
else:
raise ValueError(f"Block with {block_type=} is not supported.")
return block
class DCDownBlock2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
super().__init__()
self.downsample = downsample
self.factor = 2
self.stride = 1 if downsample else 2
self.group_size = in_channels * self.factor**2 // out_channels
self.shortcut = shortcut
out_ratio = self.factor**2
if downsample:
assert out_channels % out_ratio == 0
out_channels = out_channels // out_ratio
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.conv(hidden_states)
if self.downsample:
x = F.pixel_unshuffle(x, self.factor)
if self.shortcut:
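# parameter-free shortcut: pixel-unshuffle the input and average channel groups down to out_channels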
y = F.pixel_unshuffle(hidden_states, self.factor)
y = y.unflatten(1, (-1, self.group_size))
y = y.mean(dim=2)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class DCUpBlock2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
interpolate: bool = False,
shortcut: bool = True,
interpolation_mode: str = "nearest",
) -> None:
super().__init__()
self.interpolate = interpolate
self.interpolation_mode = interpolation_mode
self.shortcut = shortcut
self.factor = 2
self.repeats = out_channels * self.factor**2 // in_channels
out_ratio = self.factor**2
if not interpolate:
out_channels = out_channels * out_ratio
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.interpolate:
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
x = self.conv(x)
else:
x = self.conv(hidden_states)
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
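# parameter-free shortcut: duplicate channels, then pixel-shuffle back up to the output resolution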
y = hidden_states.repeat_interleave(self.repeats, dim=1)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
hidden_states = x
return hidden_states
class Encoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
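# when the first stage has no layers, the stem itself downsamples straight into the second stage's width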
if layers_per_block[0] > 0:
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
kernel_size=3,
stride=1,
padding=1,
)
else:
self.conv_in = DCDownBlock2d(
in_channels=in_channels,
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=False,
)
down_blocks = []
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
down_block_list = []
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_multiscales=qkv_multiscales[i],
)
down_block_list.append(block)
if i < num_blocks - 1 and num_layers > 0:
downsample_block = DCDownBlock2d(
in_channels=out_channel,
out_channels=block_out_channels[i + 1],
downsample=downsample_block_type == "pixel_unshuffle",
shortcut=True,
)
down_block_list.append(downsample_block)
down_blocks.append(nn.Sequential(*down_block_list))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
self.norm_factor = 1
norm_in_channels = block_out_channels[-1]
norm_out_channels = latent_channels
self.norm_group_size = norm_in_channels * self.norm_factor**2 // norm_out_channels
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
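# parameter-free residual for the latent projection: average feature-channel groups and add to conv_out's output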
x = F.pixel_unshuffle(hidden_states, self.norm_factor)
x = x.unflatten(1, (-1, self.norm_group_size))
x = x.mean(dim=2)
hidden_states = self.conv_out(hidden_states) + x
return hidden_states
class Decoder(nn.Module):
def __init__(
self,
in_channels: int,
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle",
upsample_shortcut: str = "duplicating",
):
super().__init__()
num_blocks = len(block_out_channels)
if isinstance(block_type, str):
block_type = (block_type,) * num_blocks
if isinstance(norm_type, str):
norm_type = (norm_type,) * num_blocks
if isinstance(act_fn, str):
act_fn = (act_fn,) * num_blocks
self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
self.norm_factor = 1
self.norm_repeats = block_out_channels[-1] * self.norm_factor**2 // latent_channels
up_blocks = []
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
up_block_list = []
if i < num_blocks - 1 and num_layers > 0:
upsample_block = DCUpBlock2d(
block_out_channels[i + 1],
out_channel,
interpolate=upsample_block_type == "interpolate",
shortcut=upsample_shortcut,
)
up_block_list.append(upsample_block)
for _ in range(num_layers):
block = get_block(
block_type[i],
out_channel,
out_channel,
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_multiscales=qkv_multiscales[i],
)
up_block_list.append(block)
up_blocks.insert(0, nn.Sequential(*up_block_list))
self.up_blocks = nn.ModuleList(up_blocks)
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_out = None
if layers_per_block[0] > 0:
self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
else:
self.conv_out = DCUpBlock2d(
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
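# parameter-free residual for the latent expansion: duplicate latent channels to match conv_in's output width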
x = hidden_states.repeat_interleave(self.norm_repeats, dim=1)
x = F.pixel_shuffle(x, self.norm_factor)
hidden_states = self.conv_in(hidden_states) + x
for up_block in reversed(self.up_blocks):
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
[SANA](https://arxiv.org/abs/2410.10629).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Args:
in_channels (`int`, defaults to `3`):
The number of input channels in samples.
latent_channels (`int`, defaults to `32`):
The number of channels in the latent space representation.
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
The type(s) of block to use in the encoder.
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
The type(s) of block to use in the decoder.
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
The number of output channels for each block in the encoder.
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
The number of output channels for each block in the decoder.
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
The number of layers per block in the encoder.
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
The number of layers per block in the decoder.
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
The type of block to use for upsampling in the decoder.
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
The type of block to use for downsampling in the encoder.
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder.
scaling_factor (`float`, defaults to `1.0`):
The factor by which latents are scaled: latents are multiplied by `scaling_factor` after encoding and
divided by it before decoding, so that they have approximately unit variance when used with a diffusion model.
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 32,
attention_head_dim: int = 32,
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
scaling_factor: float = 1.0,
) -> None:
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=encoder_block_types,
block_out_channels=encoder_block_out_channels,
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
)
self.decoder = Decoder(
in_channels=in_channels,
latent_channels=latent_channels,
attention_head_dim=attention_head_dim,
block_type=decoder_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=decoder_layers_per_block,
qkv_multiscales=decoder_qkv_multiscales,
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
)
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
self.temporal_compression_ratio = 1
# When decoding a batch of image latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single latent at a time.
self.use_slicing = False
# When decoding spatially large image latents, the memory requirement is very high. By breaking the latent
# into smaller tiles, performing multiple forward passes for decoding, and then blending the intermediate
# tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The stride between two consecutive vertical tiles. Keeping the stride smaller than the tile height creates
an overlap that avoids tiling artifacts across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. Keeping the stride smaller than the tile width creates
an overlap that avoids tiling artifacts across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x, return_dict=False)[0]
encoded = self.encoder(x)
return encoded
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, defaults to `True`):
Whether to return an [`~models.autoencoders.vae.EncoderOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, an
[`~models.autoencoders.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
encoded = torch.cat(encoded_slices)
else:
encoded = self._encode(x)
if not return_dict:
return (encoded,)
return EncoderOutput(latent=encoded)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=False)[0]
decoded = self.decoder(z)
return decoded
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Decode a batch of latents into images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.size(0) > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.encode(sample, return_dict=False)[0]
dec = self.decode(z, return_dict=False)[0]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -30,6 +30,19 @@ from ..unets.unet_2d_blocks import (
)
@dataclass
class EncoderOutput(BaseOutput):
r"""
Output of encoding method.
Args:
latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
The encoded latent.
"""
latent: torch.Tensor
@dataclass
class DecoderOutput(BaseOutput):
r"""

View File

@@ -512,20 +512,24 @@ else:
class RMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
self.weight = None
self.bias = None
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
if bias:
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
@@ -537,6 +541,8 @@ class RMSNorm(nn.Module):
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
hidden_states = hidden_states.to(input_dtype)
@@ -566,3 +572,21 @@ class LpNorm(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
def get_normalization(
norm_type: str = "batch_norm",
num_features: Optional[int] = None,
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
) -> nn.Module:
if norm_type == "rms_norm":
norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
elif norm_type == "layer_norm":
norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
elif norm_type == "batch_norm":
norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
else:
raise ValueError(f"{norm_type=} is not supported.")
return norm

View File

@@ -47,6 +47,21 @@ class AuraFlowTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderDC(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderDC
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_dc_config(self):
return {
"in_channels": 3,
"latent_channels": 4,
"attention_head_dim": 2,
"encoder_block_types": (
"ResBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (8, 8),
"decoder_block_out_channels": (8, 8),
"encoder_qkv_multiscales": ((), (5,)),
"decoder_qkv_multiscales": ((), (5,)),
"encoder_layers_per_block": (1, 1),
"decoder_layers_per_block": [1, 1],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass