Compare commits

...

65 Commits

Author SHA1 Message Date
sayakpaul
403d3f20f7 small nits. 2026-03-05 14:05:53 +05:30
Sayak Paul
441224ac00 Merge branch 'main' into rae 2026-03-05 12:27:28 +05:30
dg845
af0bed007a Merge branch 'main' into rae 2026-03-04 17:04:49 -08:00
Ando
ed9bcfd7a9 Merge branch 'huggingface:main' into rae 2026-03-04 19:21:12 +08:00
Kashif Rasul
05d3edca66 use randn_tensor 2026-03-04 10:16:07 +00:00
Kashif Rasul
f4ec0f1443 remove unittest 2026-03-04 10:12:40 +00:00
Kashif Rasul
fa016b196c rename 2026-03-04 09:55:54 +00:00
Kashif Rasul
33d98a85da fix api 2026-03-04 09:55:25 +00:00
Kashif Rasul
14d918ee88 Merge branch 'main' into rae 2026-03-04 10:18:06 +01:00
Kashif Rasul
bc59324a2f Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:12:50 +01:00
Kashif Rasul
b9a5266cec _noising takes a generator 2026-03-04 09:12:19 +00:00
Kashif Rasul
876e930780 remove optional 2026-03-04 09:09:09 +00:00
Kashif Rasul
df1af7d907 Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:04:23 +01:00
Kashif Rasul
af75d8b9e2 inline 2026-03-04 09:03:37 +00:00
Kashif Rasul
e805be989e use buffer 2026-03-04 09:00:09 +00:00
Kashif Rasul
3958fda3bf Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 09:53:33 +01:00
Kashif Rasul
196f8a36c7 error out as soon as possible and add comments 2026-03-04 08:52:08 +00:00
Sayak Paul
9c0f96b303 Merge branch 'main' into rae 2026-03-03 17:06:14 +05:30
Kashif Rasul
bc71889852 update training script 2026-03-03 09:10:58 +00:00
Kashif Rasul
3a6689518f add dispatch forward and update conversion script 2026-03-03 09:03:28 +00:00
Kashif Rasul
5817416a19 fix test 2026-03-02 08:11:31 +00:00
Kashif Rasul
e834e498b2 _strip_final_layernorm_affine for training script 2026-02-28 19:40:19 +00:00
Kashif Rasul
f15873af72 strip final layernorm when converting 2026-02-28 19:35:21 +00:00
Sayak Paul
bff48d317e Merge branch 'main' into rae 2026-02-28 22:01:01 +05:30
Kashif Rasul
cd86873ea6 make quality 2026-02-28 16:28:04 +00:00
Kashif Rasul
34787e5b9b use ModelTesterMixin and AutoencoderTesterMixin 2026-02-28 16:22:47 +00:00
Kashif Rasul
9ada5768e5 remove config 2026-02-28 16:05:19 +00:00
Kashif Rasul
8861a8082a fix slow test 2026-02-28 15:57:10 +00:00
Kashif Rasul
03e757ca73 Encoder is frozen 2026-02-28 15:35:28 +00:00
Kashif Rasul
c717498fa3 use image url 2026-02-28 15:08:56 +00:00
Kashif Rasul
1b4a43f59d Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:43:20 +01:00
Kashif Rasul
6a78767864 Update examples/research_projects/autoencoder_rae/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:42:45 +01:00
Kashif Rasul
663b580418 latebt normalization buffers are now always registered with no-op defaults 2026-02-26 10:45:30 +00:00
Kashif Rasul
d965cabe79 fix conversion script review 2026-02-26 10:44:27 +00:00
Kashif Rasul
5c85781519 fix train script to use pretrained 2026-02-26 10:38:47 +00:00
Kashif Rasul
c71cb44299 Merge branch 'rae' of https://github.com/Ando233/diffusers into rae 2026-02-26 10:30:32 +00:00
Kashif Rasul
dca59233f6 address reviews 2026-02-26 10:30:26 +00:00
Kashif Rasul
b3ffd6344a cleanups 2026-02-26 10:26:30 +00:00
Kashif Rasul
7debd07541 Merge branch 'main' into rae 2026-02-26 11:08:08 +01:00
Kashif Rasul
b297868201 fixes from pretrained weights 2026-02-25 13:38:22 +00:00
Kashif Rasul
28a02eb226 undo last change 2026-02-23 10:05:24 +00:00
Kashif Rasul
61885f37e3 added encoder_image_size config 2026-02-23 09:59:26 +00:00
Kashif Rasul
c68b812cb0 fix entrypoint for instantiating the AutoencoderRAE 2026-02-23 09:40:18 +00:00
Kashif Rasul
d8b2983b9e Merge branch 'main' into rae 2026-02-17 10:10:40 +01:00
Kashif Rasul
d06b501850 fix training script 2026-02-16 13:00:00 +00:00
Kashif Rasul
a4fc9f64b2 simplify mixins 2026-02-16 12:52:20 +00:00
Kashif Rasul
fc5295951a cleanup 2026-02-16 12:40:36 +00:00
Kashif Rasul
96520c4ff1 move loss to training script 2026-02-16 12:35:18 +00:00
Kashif Rasul
d3cbd5a60b fix argument 2026-02-16 00:03:54 +00:00
Kashif Rasul
906d79a432 input and ground truth sizes have to be the same 2026-02-16 00:02:27 +00:00
Kashif Rasul
9522e68a5b example traiing script 2026-02-15 23:56:19 +00:00
Kashif Rasul
6a9bde6964 remove unneeded class 2026-02-15 23:55:06 +00:00
Kashif Rasul
e6d449933d use attention 2026-02-15 23:50:52 +00:00
Kashif Rasul
7cbbf271f3 use imports 2026-02-15 23:33:30 +00:00
Kashif Rasul
202b14f6a4 add rae to diffusers script 2026-02-15 23:19:53 +00:00
Kashif Rasul
0d59b22732 cleanup 2026-02-15 23:19:13 +00:00
Kashif Rasul
d7cb12470b use mean and std convention 2026-02-15 22:57:02 +00:00
Kashif Rasul
f06ea7a901 fix latent_mean / latent_var init types to accept config-friendly inputs 2026-02-15 22:51:36 +00:00
Kashif Rasul
25bc9e334c initial doc 2026-02-15 22:44:46 +00:00
Kashif Rasul
24acab0bcc make fix-copies 2026-02-15 22:44:16 +00:00
Kashif Rasul
0850c8cdc9 fix formatting 2026-02-15 22:39:59 +00:00
Kashif Rasul
3ecf89d044 Merge branch 'main' into rae 2026-02-15 23:05:44 +01:00
Ando
a3926d77d7 Merge branch 'main' into rae 2026-01-28 20:31:20 +08:00
wangyuqi
f82cecc298 feat: finish first version of autoencoder_rae 2026-01-28 20:19:31 +08:00
wangyuqi
382aad0a6c feat: implement three RAE encoders(dinov2, siglip2, mae) 2026-01-25 02:54:35 +08:00
11 changed files with 1977 additions and 0 deletions

View File

@@ -460,6 +460,8 @@
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/autoencoder_rae
title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck

View File

@@ -0,0 +1,89 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoencoderRAE
The Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
The following RAE models are released and supported in Diffusers:
| Model | Encoder | Latent shape (224px input) |
|:------|:--------|:---------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
## Loading a pretrained model
```python
from diffusers import AutoencoderRAE
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```
## Encoding and decoding a real image
```python
import torch
from diffusers import AutoencoderRAE
from diffusers.utils import load_image
from torchvision.transforms.functional import to_tensor, to_pil_image
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
image = image.convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
with torch.no_grad():
latents = model.encode(x).latent # (1, 768, 16, 16)
recon = model.decode(latents).sample # (1, 3, 256, 256)
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
## Latent normalization
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
```python
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
latents = model.encode(x).latent # normalized latents
recon = model.decode(latents).sample
```
## AutoencoderRAE
[[autodoc]] AutoencoderRAE
- encode
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,66 @@
# Training AutoencoderRAE
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
It follows the same high-level training recipe as the official RAE stage-1 setup:
- frozen encoder
- train decoder
- pixel reconstruction loss
- optional encoder feature consistency loss
## Quickstart
### Resume or finetune from pretrained weights
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
### Train from scratch with a pretrained encoder
The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base.
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--encoder_type dinov2 \
--encoder_name_or_path facebook/dinov2-with-registers-base \
--encoder_input_size 224 \
--patch_size 16 \
--image_size 256 \
--decoder_hidden_size 1152 \
--decoder_num_hidden_layers 28 \
--decoder_num_attention_heads 16 \
--decoder_intermediate_size 4096 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
Dataset format is expected to be `ImageFolder`-compatible:
```text
train_data_dir/
class_a/
img_0001.jpg
class_b/
img_0002.jpg
```

View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from diffusers import AutoencoderRAE
from diffusers.optimization import get_scheduler
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")
parser.add_argument(
"--train_data_dir",
type=str,
required=True,
help="Path to an ImageFolder-style dataset root.",
)
parser.add_argument(
"--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
)
parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--resolution", type=int, default=256)
parser.add_argument("--center_crop", action="store_true")
parser.add_argument("--random_flip", action="store_true")
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--dataloader_num_workers", type=int, default=4)
parser.add_argument("--num_train_epochs", type=int, default=10)
parser.add_argument("--max_train_steps", type=int, default=None)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--adam_beta1", type=float, default=0.9)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
parser.add_argument("--adam_epsilon", type=float, default=1e-8)
parser.add_argument("--lr_scheduler", type=str, default="cosine")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--checkpointing_steps", type=int, default=1000)
parser.add_argument("--validation_steps", type=int, default=500)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
)
parser.add_argument(
"--encoder_name_or_path",
type=str,
default=None,
help=(
"HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
"When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
"into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
),
)
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
parser.add_argument("--encoder_hidden_size", type=int, default=768)
parser.add_argument("--encoder_patch_size", type=int, default=14)
parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
parser.add_argument("--encoder_input_size", type=int, default=224)
parser.add_argument("--patch_size", type=int, default=16)
parser.add_argument("--image_size", type=int, default=256)
parser.add_argument("--num_channels", type=int, default=3)
parser.add_argument("--decoder_hidden_size", type=int, default=1152)
parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
parser.add_argument("--decoder_intermediate_size", type=int, default=4096)
parser.add_argument("--noise_tau", type=float, default=0.0)
parser.add_argument("--scaling_factor", type=float, default=1.0)
parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument(
"--reconstruction_loss_type",
type=str,
choices=["l1", "mse"],
default="l1",
help="Pixel reconstruction loss.",
)
parser.add_argument(
"--encoder_loss_weight",
type=float,
default=0.0,
help="Weight for encoder feature consistency loss in the training loop.",
)
parser.add_argument(
"--use_encoder_loss",
action="store_true",
help="Enable encoder feature consistency loss term in the training loop.",
)
parser.add_argument("--report_to", type=str, default="tensorboard")
return parser.parse_args()
def build_transforms(args):
image_transforms = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
]
if args.random_flip:
image_transforms.append(transforms.RandomHorizontalFlip())
image_transforms.append(transforms.ToTensor())
return transforms.Compose(image_transforms)
def compute_losses(
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
):
decoded = model(pixel_values).sample
if decoded.shape[-2:] != pixel_values.shape[-2:]:
raise ValueError(
"Training requires matching reconstruction and target sizes, got "
f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
)
if reconstruction_loss_type == "l1":
reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
else:
reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())
encoder_loss = torch.zeros_like(reconstruction_loss)
if use_encoder_loss and encoder_loss_weight > 0:
base_model = model.module if hasattr(model, "module") else model
target_encoder_input = base_model._resize_and_normalize(pixel_values)
reconstructed_encoder_input = base_model._resize_and_normalize(decoded)
encoder_forward_kwargs = {"model": base_model.encoder}
if base_model.config.encoder_type == "mae":
encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size
with torch.no_grad():
target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
reconstructed_tokens = base_model._encoder_forward_fn(
images=reconstructed_encoder_input, **encoder_forward_kwargs
)
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
return decoded, loss, reconstruction_loss, encoder_loss
def _strip_final_layernorm_affine(state_dict, prefix=""):
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
"""Load pretrained HF transformers encoder weights into the model's encoder."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()}
state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
model.encoder.load_state_dict(state_dict, strict=False)
def main():
args = parse_args()
if args.resolution != args.image_size:
raise ValueError(
f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
"for stage-1 reconstruction loss."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
project_config=accelerator_project_config,
log_with=args.report_to,
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))
def collate_fn(examples):
pixel_values = torch.stack([example[0] for example in examples]).float()
return {"pixel_values": pixel_values}
train_dataloader = DataLoader(
dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
pin_memory=True,
drop_last=True,
)
if args.pretrained_model_name_or_path is not None:
model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
else:
model = AutoencoderRAE(
encoder_type=args.encoder_type,
encoder_hidden_size=args.encoder_hidden_size,
encoder_patch_size=args.encoder_patch_size,
encoder_num_hidden_layers=args.encoder_num_hidden_layers,
decoder_hidden_size=args.decoder_hidden_size,
decoder_num_hidden_layers=args.decoder_num_hidden_layers,
decoder_num_attention_heads=args.decoder_num_attention_heads,
decoder_intermediate_size=args.decoder_intermediate_size,
patch_size=args.patch_size,
encoder_input_size=args.encoder_input_size,
image_size=args.image_size,
num_channels=args.num_channels,
noise_tau=args.noise_tau,
reshape_to_2d=args.reshape_to_2d,
use_encoder_loss=args.use_encoder_loss,
scaling_factor=args.scaling_factor,
)
if args.encoder_name_or_path is not None:
_load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")
model.encoder.requires_grad_(False)
model.decoder.requires_grad_(True)
model.train()
optimizer = torch.optim.AdamW(
(p for p in model.parameters() if p.requires_grad),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
if overrode_max_train_steps:
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if accelerator.is_main_process:
accelerator.init_trackers("train_autoencoder_rae", config=vars(args))
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(model):
pixel_values = batch["pixel_values"]
_, loss, reconstruction_loss, encoder_loss = compute_losses(
model,
pixel_values,
reconstruction_loss_type=args.reconstruction_loss_type,
use_encoder_loss=args.use_encoder_loss,
encoder_loss_weight=args.encoder_loss_weight,
)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {
"loss": loss.detach().item(),
"reconstruction_loss": reconstruction_loss.detach().item(),
"encoder_loss": encoder_loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step % args.validation_steps == 0:
with torch.no_grad():
_, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
model,
pixel_values,
reconstruction_loss_type=args.reconstruction_loss_type,
use_encoder_loss=args.use_encoder_loss,
encoder_loss_weight=args.encoder_loss_weight,
)
accelerator.log(
{
"val/loss": val_loss.detach().item(),
"val/reconstruction_loss": val_reconstruction_loss.detach().item(),
"val/encoder_loss": val_encoder_loss.detach().item(),
},
step=global_step,
)
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_path)
logger.info(f"Saved checkpoint to {save_path}")
if global_step >= args.max_train_steps:
break
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,403 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi, hf_hub_download
from diffusers import AutoencoderRAE
DECODER_CONFIGS = {
"ViTB": {
"decoder_hidden_size": 768,
"decoder_intermediate_size": 3072,
"decoder_num_attention_heads": 12,
"decoder_num_hidden_layers": 12,
},
"ViTL": {
"decoder_hidden_size": 1024,
"decoder_intermediate_size": 4096,
"decoder_num_attention_heads": 16,
"decoder_num_hidden_layers": 24,
},
"ViTXL": {
"decoder_hidden_size": 1152,
"decoder_intermediate_size": 4096,
"decoder_num_attention_heads": 16,
"decoder_num_hidden_layers": 28,
},
}
ENCODER_DEFAULT_NAME_OR_PATH = {
"dinov2": "facebook/dinov2-with-registers-base",
"siglip2": "google/siglip2-base-patch16-256",
"mae": "facebook/vit-mae-base",
}
ENCODER_HIDDEN_SIZE = {
"dinov2": 768,
"siglip2": 768,
"mae": 768,
}
ENCODER_PATCH_SIZE = {
"dinov2": 14,
"siglip2": 16,
"mae": 16,
}
DEFAULT_DECODER_SUBDIR = {
"dinov2": "decoders/dinov2/wReg_base",
"mae": "decoders/mae/base_p16",
"siglip2": "decoders/siglip2/base_p16_i256",
}
DEFAULT_STATS_SUBDIR = {
"dinov2": "stats/dinov2/wReg_base",
"mae": "stats/mae/base_p16",
"siglip2": "stats/siglip2/base_p16_i256",
}
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
STATS_FILE_CANDIDATES = ("stat.pt",)
def dataset_case_candidates(name: str) -> tuple[str, ...]:
return (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
class RepoAccessor:
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
self.repo_or_path = repo_or_path
self.cache_dir = cache_dir
self.local_root: Path | None = None
self.repo_id: str | None = None
self.repo_files: set[str] | None = None
root = Path(repo_or_path)
if root.exists() and root.is_dir():
self.local_root = root
else:
self.repo_id = repo_or_path
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
def exists(self, relative_path: str) -> bool:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return (self.local_root / relative_path).is_file()
return relative_path in self.repo_files
def fetch(self, relative_path: str) -> Path:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return self.local_root / relative_path
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
return Path(downloaded)
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
state_dict = maybe_wrapped
for k in ("model", "module", "state_dict"):
if isinstance(state_dict, dict) and k in state_dict and isinstance(state_dict[k], dict):
state_dict = state_dict[k]
out = dict(state_dict)
if len(out) > 0 and all(key.startswith("module.") for key in out):
out = {key[len("module.") :]: value for key, value in out.items()}
if len(out) > 0 and all(key.startswith("decoder.") for key in out):
out = {key[len("decoder.") :]: value for key, value in out.items()}
return out
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
"""
Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder.
Example mappings:
- `...attention.attention.query.*` -> `...attention.to_q.*`
- `...attention.attention.key.*` -> `...attention.to_k.*`
- `...attention.attention.value.*` -> `...attention.to_v.*`
- `...attention.output.dense.*` -> `...attention.to_out.0.*`
"""
remapped: dict[str, Any] = {}
for key, value in state_dict.items():
new_key = key
new_key = new_key.replace(".attention.attention.query.", ".attention.to_q.")
new_key = new_key.replace(".attention.attention.key.", ".attention.to_k.")
new_key = new_key.replace(".attention.attention.value.", ".attention.to_v.")
new_key = new_key.replace(".attention.output.dense.", ".attention.to_out.0.")
remapped[new_key] = value
return remapped
def resolve_decoder_file(
accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
if decoder_checkpoint is not None:
if accessor.exists(decoder_checkpoint):
return decoder_checkpoint
raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
for name in DECODER_FILE_CANDIDATES:
candidate = f"{base}/{name}"
if accessor.exists(candidate):
return candidate
raise FileNotFoundError(
f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
)
def resolve_stats_file(
accessor: RepoAccessor,
encoder_type: str,
dataset_name: str,
stats_checkpoint: str | None,
) -> str | None:
if stats_checkpoint is not None:
if accessor.exists(stats_checkpoint):
return stats_checkpoint
raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
base = DEFAULT_STATS_SUBDIR[encoder_type]
for dataset in dataset_case_candidates(dataset_name):
for name in STATS_FILE_CANDIDATES:
candidate = f"{base}/{dataset}/{name}"
if accessor.exists(candidate):
return candidate
return None
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
if not isinstance(stats_obj, dict):
return None, None
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
mean = stats_obj.get("mean", None)
var = stats_obj.get("var", None)
if mean is None and var is None:
return None, None
latents_std = None
if var is not None:
if isinstance(var, torch.Tensor):
latents_std = torch.sqrt(var + 1e-5)
else:
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
return mean, latents_std
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
"""Remove final layernorm weight/bias from encoder state dict.
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
Stripping these keys means the model keeps its default init values, which
is functionally equivalent to setting elementwise_affine=False.
"""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
"""Download the HF encoder and extract the state dict for the inner model."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
# SiglipModel.vision_model is a SiglipVisionTransformer.
# Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
# under .vision_model, so we add the prefix to match the diffusers key layout.
hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()}
return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
def convert(args: argparse.Namespace) -> None:
accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]
decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)
print(f"Using decoder checkpoint: {decoder_relpath}")
if stats_relpath is not None:
print(f"Using stats checkpoint: {stats_relpath}")
else:
print("No stats checkpoint found; conversion will proceed without latent stats.")
if args.dry_run:
return
decoder_path = accessor.fetch(decoder_relpath)
decoder_obj = torch.load(decoder_path, map_location="cpu")
decoder_state_dict = unwrap_state_dict(decoder_obj)
decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)
latents_mean, latents_std = None, None
if stats_relpath is not None:
stats_path = accessor.fetch(stats_relpath)
stats_obj = torch.load(stats_path, map_location="cpu")
latents_mean, latents_std = extract_latent_stats(stats_obj)
decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]
# Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
from transformers import AutoConfig, AutoImageProcessor
proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
encoder_norm_mean = list(proc.image_mean)
encoder_norm_std = list(proc.image_std)
# Read encoder hidden size and patch size from HF config
encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
try:
hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
# For models like SigLIP that nest vision config
if hasattr(hf_config, "vision_config"):
hf_config = hf_config.vision_config
encoder_hidden_size = hf_config.hidden_size
encoder_patch_size = hf_config.patch_size
except Exception:
pass
# Load the actual encoder weights from HF to include in the saved model
encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)
# Build model on meta device to avoid double init overhead
with torch.device("meta"):
model = AutoencoderRAE(
encoder_type=args.encoder_type,
encoder_hidden_size=encoder_hidden_size,
encoder_patch_size=encoder_patch_size,
encoder_input_size=args.encoder_input_size,
patch_size=args.patch_size,
image_size=args.image_size,
num_channels=args.num_channels,
encoder_norm_mean=encoder_norm_mean,
encoder_norm_std=encoder_norm_std,
decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
latents_mean=latents_mean,
latents_std=latents_std,
scaling_factor=args.scaling_factor,
)
# Assemble full state dict and load with assign=True
full_state_dict = {}
# Encoder weights (prefixed with "encoder.")
for k, v in encoder_state_dict.items():
full_state_dict[f"encoder.{k}"] = v
# Decoder weights (prefixed with "decoder.")
for k, v in decoder_state_dict.items():
full_state_dict[f"decoder.{k}"] = v
# Buffers from config
full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
if latents_mean is not None:
latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
full_state_dict["_latents_mean"] = latents_mean_t
else:
full_state_dict["_latents_mean"] = torch.zeros(1)
if latents_std is not None:
latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
full_state_dict["_latents_std"] = latents_std_t
else:
full_state_dict["_latents_std"] = torch.ones(1)
model.load_state_dict(full_state_dict, strict=False, assign=True)
# Verify no critical keys are missing
model_keys = {name for name, _ in model.named_parameters()}
model_keys |= {name for name, _ in model.named_buffers()}
loaded_keys = set(full_state_dict.keys())
missing = model_keys - loaded_keys
# trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
if missing - allowed_missing:
print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")
output_path = Path(args.output_path)
output_path.mkdir(parents=True, exist_ok=True)
model.save_pretrained(output_path)
if args.verify_load:
print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
if not isinstance(loaded_model, AutoencoderRAE):
raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
print("Verification passed.")
print(f"Saved converted AutoencoderRAE to: {output_path}")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
parser.add_argument(
"--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
)
parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
parser.add_argument(
"--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
)
parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
parser.add_argument(
"--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
)
parser.add_argument(
"--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
)
parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
parser.add_argument("--encoder_input_size", type=int, default=224)
parser.add_argument("--patch_size", type=int, default=16)
parser.add_argument("--image_size", type=int, default=None)
parser.add_argument("--num_channels", type=int, default=3)
parser.add_argument("--scaling_factor", type=float, default=1.0)
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
parser.add_argument(
"--verify_load",
action="store_true",
help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
convert(args)
if __name__ == "__main__":
main()

View File

@@ -202,6 +202,7 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderRAE",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
@@ -974,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,

View File

@@ -49,6 +49,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
@@ -168,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,

View File

@@ -18,6 +18,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_rae import AutoencoderRAE
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel

View File

@@ -0,0 +1,692 @@
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import sqrt
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.import_utils import is_transformers_available
from ...utils.torch_utils import randn_tensor
if is_transformers_available():
from transformers import (
Dinov2WithRegistersConfig,
Dinov2WithRegistersModel,
SiglipVisionConfig,
SiglipVisionModel,
ViTMAEConfig,
ViTMAEModel,
)
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import Attention
from ..embeddings import get_2d_sincos_pos_embed
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Per-encoder forward functions
# ---------------------------------------------------------------------------
# Each function takes the raw transformers model + images and returns patch
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True)
unused_token_num = 5 # 1 CLS + 4 register tokens
return outputs.last_hidden_state[:, unused_token_num:]
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
return outputs.last_hidden_state
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
h, w = images.shape[2], images.shape[3]
patch_num = int(h * w // patch_size**2)
if patch_num * patch_size**2 != h * w:
raise ValueError("Image size should be divisible by patch size.")
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
outputs = model(images, noise, interpolate_pos_encoding=True)
return outputs.last_hidden_state[:, 1:] # remove cls token
# ---------------------------------------------------------------------------
# Encoder construction helpers
# ---------------------------------------------------------------------------
def _build_encoder(
encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
"""Build a frozen encoder from config (no pretrained download)."""
num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64
if encoder_type == "dinov2":
config = Dinov2WithRegistersConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=518,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
)
model = Dinov2WithRegistersModel(config)
# RAE strips the final layernorm affine params (identity LN). Remove them from
# the architecture so `from_pretrained` doesn't leave them on the meta device.
model.layernorm.weight = None
model.layernorm.bias = None
elif encoder_type == "siglip2":
config = SiglipVisionConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=256,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
)
model = SiglipVisionModel(config)
# See dinov2 comment above.
model.vision_model.post_layernorm.weight = None
model.vision_model.post_layernorm.bias = None
elif encoder_type == "mae":
config = ViTMAEConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=224,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
mask_ratio=0.0,
)
model = ViTMAEModel(config)
# See dinov2 comment above.
model.layernorm.weight = None
model.layernorm.bias = None
else:
raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")
model.requires_grad_(False)
return model
_ENCODER_FORWARD_FNS = {
"dinov2": _dinov2_encoder_forward,
"siglip2": _siglip2_encoder_forward,
"mae": _mae_encoder_forward,
}
@dataclass
class RAEDecoderOutput(BaseOutput):
"""
Output of `RAEDecoder`.
Args:
logits (`torch.Tensor`):
Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
"""
logits: torch.Tensor
class ViTMAEIntermediate(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
super().__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = get_activation(hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTMAEOutput(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
super().__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTMAELayer(nn.Module):
"""
This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
"""
def __init__(
self,
*,
hidden_size: int,
num_attention_heads: int,
intermediate_size: int,
qkv_bias: bool = True,
layer_norm_eps: float = 1e-12,
hidden_dropout_prob: float = 0.0,
attention_probs_dropout_prob: float = 0.0,
hidden_act: str = "gelu",
):
super().__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
)
self.attention = Attention(
query_dim=hidden_size,
heads=num_attention_heads,
dim_head=hidden_size // num_attention_heads,
dropout=attention_probs_dropout_prob,
bias=qkv_bias,
)
self.intermediate = ViTMAEIntermediate(
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
)
self.output = ViTMAEOutput(
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
)
self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attention_output = self.attention(self.layernorm_before(hidden_states))
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
return layer_output
class RAEDecoder(nn.Module):
"""Lightweight RAE decoder."""
def __init__(
self,
hidden_size: int = 768,
decoder_hidden_size: int = 512,
decoder_num_hidden_layers: int = 8,
decoder_num_attention_heads: int = 16,
decoder_intermediate_size: int = 2048,
num_patches: int = 256,
patch_size: int = 16,
num_channels: int = 3,
image_size: int = 256,
qkv_bias: bool = True,
layer_norm_eps: float = 1e-12,
hidden_dropout_prob: float = 0.0,
attention_probs_dropout_prob: float = 0.0,
hidden_act: str = "gelu",
):
super().__init__()
self.decoder_hidden_size = decoder_hidden_size
self.patch_size = patch_size
self.num_channels = num_channels
self.image_size = image_size
self.num_patches = num_patches
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
self.decoder_layers = nn.ModuleList(
[
ViTMAELayer(
hidden_size=decoder_hidden_size,
num_attention_heads=decoder_num_attention_heads,
intermediate_size=decoder_intermediate_size,
qkv_bias=qkv_bias,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_act=hidden_act,
)
for _ in range(decoder_num_hidden_layers)
]
)
self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
self.gradient_checkpointing = False
self._initialize_weights(num_patches)
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
def _initialize_weights(self, num_patches: int):
# Skip initialization when parameters are on meta device (e.g. during
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
# The weights are initialized.
if self.decoder_pos_embed.device.type == "meta":
return
grid_size = int(num_patches**0.5)
pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
grid_size,
cls_token=True,
extra_tokens=1,
output_type="pt",
device=self.decoder_pos_embed.device,
)
self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
embeddings_positions = embeddings.shape[1] - 1
num_positions = self.decoder_pos_embed.shape[1] - 1
class_pos_embed = self.decoder_pos_embed[:, 0, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
dim = self.decoder_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
patch_pos_embed = F.interpolate(
patch_pos_embed,
scale_factor=(1, embeddings_positions / num_positions),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
b, l, c = x.shape
if l == self.num_patches:
return x
h = w = int(l**0.5)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
return x
def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
patch_size, num_channels = self.patch_size, self.num_channels
original_image_size = (
original_image_size if original_image_size is not None else (self.image_size, self.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
raise ValueError(
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
)
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_h,
num_patches_w,
patch_size,
patch_size,
num_channels,
)
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_h * patch_size,
num_patches_w * patch_size,
)
return pixel_values
def forward(
self,
hidden_states: torch.Tensor,
*,
interpolate_pos_encoding: bool = False,
drop_cls_token: bool = False,
return_dict: bool = True,
) -> RAEDecoderOutput | tuple[torch.Tensor]:
x = self.decoder_embed(hidden_states)
if drop_cls_token:
x_ = x[:, 1:, :]
x_ = self.interpolate_latent(x_)
else:
x_ = self.interpolate_latent(x)
cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
x = torch.cat([cls_token, x_], dim=1)
if interpolate_pos_encoding:
if not drop_cls_token:
raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
for layer_module in self.decoder_layers:
hidden_states = layer_module(hidden_states)
hidden_states = self.decoder_norm(hidden_states)
logits = self.decoder_pred(hidden_states)
logits = logits[:, 1:, :]
if not return_dict:
return (logits,)
return RAEDecoderOutput(logits=logits)
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
r"""
Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
images from learned representations.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Args:
encoder_type (`str`, *optional*, defaults to `"dinov2"`):
Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
encoder_hidden_size (`int`, *optional*, defaults to `768`):
Hidden size of the encoder model.
encoder_patch_size (`int`, *optional*, defaults to `14`):
Patch size of the encoder model.
encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
Number of hidden layers in the encoder model.
patch_size (`int`, *optional*, defaults to `16`):
Decoder patch size (used for unpatchify and decoder head).
encoder_input_size (`int`, *optional*, defaults to `224`):
Input size expected by the encoder.
image_size (`int`, *optional*):
Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
encoder_patch_size) ** 2`.
num_channels (`int`, *optional*, defaults to `3`):
Number of input/output channels.
encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
Channel-wise mean for encoder input normalization (ImageNet defaults).
encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
Channel-wise std for encoder input normalization (ImageNet defaults).
latents_mean (`list` or `tuple`, *optional*):
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
lists.
latents_std (`list` or `tuple`, *optional*):
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
config-serializable lists.
noise_tau (`float`, *optional*, defaults to `0.0`):
Noise level for training (adds noise to latents during training).
reshape_to_2d (`bool`, *optional*, defaults to `True`):
Whether to reshape latents to 2D (B, C, H, W) format.
use_encoder_loss (`bool`, *optional*, defaults to `False`):
Whether to use encoder hidden states in the loss (for advanced training).
"""
# NOTE: gradient checkpointing is not wired up for this model yet.
_supports_gradient_checkpointing = False
_no_split_modules = ["ViTMAELayer"]
@register_to_config
def __init__(
self,
encoder_type: str = "dinov2",
encoder_hidden_size: int = 768,
encoder_patch_size: int = 14,
encoder_num_hidden_layers: int = 12,
decoder_hidden_size: int = 512,
decoder_num_hidden_layers: int = 8,
decoder_num_attention_heads: int = 16,
decoder_intermediate_size: int = 2048,
patch_size: int = 16,
encoder_input_size: int = 224,
image_size: int | None = None,
num_channels: int = 3,
encoder_norm_mean: list | None = None,
encoder_norm_std: list | None = None,
latents_mean: list | tuple | torch.Tensor | None = None,
latents_std: list | tuple | torch.Tensor | None = None,
noise_tau: float = 0.0,
reshape_to_2d: bool = True,
use_encoder_loss: bool = False,
scaling_factor: float = 1.0,
):
super().__init__()
if encoder_type not in _ENCODER_FORWARD_FNS:
raise ValueError(
f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
)
if encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)
decoder_patch_size = patch_size
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
num_patches = (encoder_input_size // encoder_patch_size) ** 2
grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)
def _to_config_compatible(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
if isinstance(value, tuple):
return [_to_config_compatible(v) for v in value]
if isinstance(value, list):
return [_to_config_compatible(v) for v in value]
return value
def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
if value is None:
return None
if isinstance(value, torch.Tensor):
return value.detach().clone()
return torch.tensor(value, dtype=torch.float32)
latents_std_tensor = _as_optional_tensor(latents_std)
# Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
self.register_to_config(
latents_mean=_to_config_compatible(latents_mean),
latents_std=_to_config_compatible(latents_std),
)
# Frozen representation encoder (built from config, no downloads)
self.encoder: nn.Module = _build_encoder(
encoder_type=encoder_type,
hidden_size=encoder_hidden_size,
patch_size=encoder_patch_size,
num_hidden_layers=encoder_num_hidden_layers,
)
self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
num_patches = (encoder_input_size // encoder_patch_size) ** 2
# Encoder input normalization stats (ImageNet defaults)
if encoder_norm_mean is None:
encoder_norm_mean = [0.485, 0.456, 0.406]
if encoder_norm_std is None:
encoder_norm_std = [0.229, 0.224, 0.225]
encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)
# Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
latents_mean_tensor = _as_optional_tensor(latents_mean)
if latents_mean_tensor is None:
latents_mean_tensor = torch.zeros(1)
self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
if latents_std_tensor is None:
latents_std_tensor = torch.ones(1)
self.register_buffer("_latents_std", latents_std_tensor, persistent=True)
# ViT-MAE style decoder
self.decoder = RAEDecoder(
hidden_size=int(encoder_hidden_size),
decoder_hidden_size=int(decoder_hidden_size),
decoder_num_hidden_layers=int(decoder_num_hidden_layers),
decoder_num_attention_heads=int(decoder_num_attention_heads),
decoder_intermediate_size=int(decoder_intermediate_size),
num_patches=int(num_patches),
patch_size=int(decoder_patch_size),
num_channels=int(num_channels),
image_size=int(image_size),
)
self.num_patches = int(num_patches)
self.decoder_patch_size = int(decoder_patch_size)
self.decoder_image_size = int(image_size)
# Slicing support (batch dimension) similar to other diffusers autoencoders
self.use_slicing = False
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
# Per-sample random sigma in [0, noise_tau]
noise_sigma = self.config.noise_tau * torch.rand(
(x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
)
return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
_, _, h, w = x.shape
if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
x = F.interpolate(
x,
size=(self.config.encoder_input_size, self.config.encoder_input_size),
mode="bicubic",
align_corners=False,
)
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
return (x - mean) / std
def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
return x * std + mean
def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
return (z - latents_mean) / (latents_std + 1e-5)
def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
return z * (latents_std + 1e-5) + latents_mean
def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
x = self._resize_and_normalize(x)
if self.config.encoder_type == "mae":
tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
else:
tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C)
if self.training and self.config.noise_tau > 0:
tokens = self._noising(tokens, generator=generator)
if self.config.reshape_to_2d:
b, n, c = tokens.shape
side = int(sqrt(n))
if side * side != n:
raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w)
else:
z = tokens
z = self._normalize_latents(z)
# Follow diffusers convention: optionally scale latents for diffusion
if self.config.scaling_factor != 1.0:
z = z * self.config.scaling_factor
return z
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> EncoderOutput | tuple[torch.Tensor]:
if self.use_slicing and x.shape[0] > 1:
latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0)
else:
latents = self._encode(x, generator=generator)
if not return_dict:
return (latents,)
return EncoderOutput(latent=latents)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
# Undo scaling factor if applied at encode time
if self.config.scaling_factor != 1.0:
z = z / self.config.scaling_factor
z = self._denormalize_latents(z)
if self.config.reshape_to_2d:
b, c, h, w = z.shape
tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C)
else:
tokens = z
logits = self.decoder(tokens, return_dict=True).logits
x_rec = self.decoder.unpatchify(logits)
x_rec = self._denormalize_image(x_rec)
return x_rec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> DecoderOutput | tuple[torch.Tensor]:
latents = self.encode(sample, return_dict=False, generator=generator)[0]
decoded = self.decode(latents, return_dict=False)[0]
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)

View File

@@ -656,6 +656,21 @@ class AutoencoderOobleck(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderRAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -0,0 +1,300 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor
import diffusers.models.autoencoders.autoencoder_rae as _rae_module
from diffusers.models.autoencoders.autoencoder_rae import (
_ENCODER_FORWARD_FNS,
AutoencoderRAE,
_build_encoder,
)
from diffusers.utils import load_image
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
slow,
torch_all_close,
torch_device,
)
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
# ---------------------------------------------------------------------------
# Tiny test encoder for fast unit tests (no transformers dependency)
# ---------------------------------------------------------------------------
class _TinyTestEncoderModule(torch.nn.Module):
"""Minimal encoder that mimics the patch-token interface without any HF model."""
def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
super().__init__()
self.patch_size = patch_size
self.hidden_size = hidden_size
def forward(self, images: torch.Tensor) -> torch.Tensor:
pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size)
tokens = pooled.flatten(2).transpose(1, 2).contiguous()
return tokens.repeat(1, 1, self.hidden_size)
def _tiny_test_encoder_forward(model, images):
return model(images)
def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)
# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
_original_build_encoder = _build_encoder
def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
if encoder_type == "tiny_test":
return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
_rae_module._build_encoder = _patched_build_encoder
# ---------------------------------------------------------------------------
# Test config
# ---------------------------------------------------------------------------
class AutoencoderRAETesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderRAE
@property
def output_shape(self):
return (3, 16, 16)
def get_init_dict(self):
return {
"encoder_type": "tiny_test",
"encoder_hidden_size": 16,
"encoder_patch_size": 8,
"encoder_input_size": 32,
"patch_size": 4,
"image_size": 16,
"decoder_hidden_size": 32,
"decoder_num_hidden_layers": 1,
"decoder_num_attention_heads": 4,
"decoder_intermediate_size": 64,
"num_channels": 3,
"encoder_norm_mean": [0.5, 0.5, 0.5],
"encoder_norm_std": [0.5, 0.5, 0.5],
"noise_tau": 0.0,
"reshape_to_2d": True,
"scaling_factor": 1.0,
}
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_dummy_inputs(self):
return {"sample": torch.randn(2, 3, 32, 32, generator=self.generator, device="cpu").to(torch_device)}
# Bridge for AutoencoderTesterMixin which still uses the old interface
def prepare_init_args_and_inputs_for_common(self):
return self.get_init_dict(), self.get_dummy_inputs()
def _make_model(self, **overrides) -> AutoencoderRAE:
config = self.get_init_dict()
config.update(overrides)
return AutoencoderRAE(**config).to(torch_device)
class TestAutoEncoderRAE(AutoencoderRAETesterConfig, ModelTesterMixin):
"""Core model tests for AutoencoderRAE."""
@pytest.mark.skip(reason="AutoencoderRAE does not support torch dynamo yet")
def test_from_save_pretrained_dynamo(self): ...
def test_fast_encode_decode_and_forward_shapes(self):
model = self._make_model().eval()
x = torch.rand(2, 3, 32, 32, device=torch_device)
with torch.no_grad():
z = model.encode(x).latent
decoded = model.decode(z).sample
recon = model(x).sample
assert z.shape == (2, 16, 4, 4)
assert decoded.shape == (2, 3, 16, 16)
assert recon.shape == (2, 3, 16, 16)
assert torch.isfinite(recon).all().item()
def test_fast_scaling_factor_encode_and_decode_consistency(self):
torch.manual_seed(0)
model_base = self._make_model(scaling_factor=1.0).eval()
torch.manual_seed(0)
model_scaled = self._make_model(scaling_factor=2.0).eval()
x = torch.rand(2, 3, 32, 32, device=torch_device)
with torch.no_grad():
z_base = model_base.encode(x).latent
z_scaled = model_scaled.encode(x).latent
recon_base = model_base.decode(z_base).sample
recon_scaled = model_scaled.decode(z_scaled).sample
assert torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)
assert torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)
def test_fast_latents_normalization_matches_formula(self):
latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32)
latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32)
model_raw = self._make_model().eval()
model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval()
x = torch.rand(1, 3, 32, 32, device=torch_device)
with torch.no_grad():
z_raw = model_raw.encode(x).latent
z_norm = model_norm.encode(x).latent
expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / (
latents_std.to(z_raw.device, z_raw.dtype) + 1e-5
)
assert torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)
def test_fast_slicing_matches_non_slicing(self):
model = self._make_model().eval()
x = torch.rand(3, 3, 32, 32, device=torch_device)
with torch.no_grad():
model.use_slicing = False
z_no_slice = model.encode(x).latent
out_no_slice = model.decode(z_no_slice).sample
model.use_slicing = True
z_slice = model.encode(x).latent
out_slice = model.decode(z_slice).sample
assert torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5)
assert torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5)
def test_fast_noise_tau_applies_only_in_train(self):
model = self._make_model(noise_tau=0.5).to(torch_device)
x = torch.rand(2, 3, 32, 32, device=torch_device)
model.train()
torch.manual_seed(0)
z_train_1 = model.encode(x).latent
torch.manual_seed(1)
z_train_2 = model.encode(x).latent
model.eval()
torch.manual_seed(0)
z_eval_1 = model.encode(x).latent
torch.manual_seed(1)
z_eval_2 = model.encode(x).latent
assert z_train_1.shape == z_eval_1.shape
assert not torch.allclose(z_train_1, z_train_2)
assert torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)
class TestAutoEncoderRAESlicingTiling(AutoencoderRAETesterConfig, AutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderRAE."""
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests:
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_dinov2_encoder_forward_shape(self):
encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12).to(torch_device)
x = torch.rand(1, 3, 224, 224, device=torch_device)
y = _ENCODER_FORWARD_FNS["dinov2"](encoder, x)
assert y.ndim == 3
assert y.shape[0] == 1
assert y.shape[1] == 256 # (224/14)^2 - 5 (CLS + 4 register) = 251? Actually dinov2 has 256 patches
assert y.shape[2] == 768
def test_siglip2_encoder_forward_shape(self):
encoder = _build_encoder("siglip2", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
x = torch.rand(1, 3, 224, 224, device=torch_device)
y = _ENCODER_FORWARD_FNS["siglip2"](encoder, x)
assert y.ndim == 3
assert y.shape[0] == 1
assert y.shape[1] == 196 # (224/16)^2
assert y.shape[2] == 768
def test_mae_encoder_forward_shape(self):
encoder = _build_encoder("mae", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
x = torch.rand(1, 3, 224, 224, device=torch_device)
y = _ENCODER_FORWARD_FNS["mae"](encoder, x, patch_size=16)
assert y.ndim == 3
assert y.shape[0] == 1
assert y.shape[1] == 196 # (224/16)^2
assert y.shape[2] == 768
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEIntegrationTests:
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_autoencoder_rae_from_pretrained_dinov2(self):
model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device)
model.eval()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
)
image = image.convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to(torch_device)
with torch.no_grad():
latents = model.encode(x).latent
assert latents.shape == (1, 768, 16, 16)
recon = model.decode(latents).sample
assert recon.shape == (1, 3, 256, 256)
assert torch.isfinite(recon).all().item()
# fmt: off
expected_latent_slice = torch.tensor([0.7617, 0.8824, -0.4891])
expected_recon_slice = torch.tensor([0.1263, 0.1355, 0.1435])
# fmt: on
assert torch_all_close(latents[0, :3, 0, 0].float().cpu(), expected_latent_slice, atol=1e-3)
assert torch_all_close(recon[0, 0, 0, :3].float().cpu(), expected_recon_slice, atol=1e-3)