mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-07 17:21:48 +08:00
Compare commits
65 Commits
main
...
Ando233-ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
403d3f20f7 | ||
|
|
441224ac00 | ||
|
|
af0bed007a | ||
|
|
ed9bcfd7a9 | ||
|
|
05d3edca66 | ||
|
|
f4ec0f1443 | ||
|
|
fa016b196c | ||
|
|
33d98a85da | ||
|
|
14d918ee88 | ||
|
|
bc59324a2f | ||
|
|
b9a5266cec | ||
|
|
876e930780 | ||
|
|
df1af7d907 | ||
|
|
af75d8b9e2 | ||
|
|
e805be989e | ||
|
|
3958fda3bf | ||
|
|
196f8a36c7 | ||
|
|
9c0f96b303 | ||
|
|
bc71889852 | ||
|
|
3a6689518f | ||
|
|
5817416a19 | ||
|
|
e834e498b2 | ||
|
|
f15873af72 | ||
|
|
bff48d317e | ||
|
|
cd86873ea6 | ||
|
|
34787e5b9b | ||
|
|
9ada5768e5 | ||
|
|
8861a8082a | ||
|
|
03e757ca73 | ||
|
|
c717498fa3 | ||
|
|
1b4a43f59d | ||
|
|
6a78767864 | ||
|
|
663b580418 | ||
|
|
d965cabe79 | ||
|
|
5c85781519 | ||
|
|
c71cb44299 | ||
|
|
dca59233f6 | ||
|
|
b3ffd6344a | ||
|
|
7debd07541 | ||
|
|
b297868201 | ||
|
|
28a02eb226 | ||
|
|
61885f37e3 | ||
|
|
c68b812cb0 | ||
|
|
d8b2983b9e | ||
|
|
d06b501850 | ||
|
|
a4fc9f64b2 | ||
|
|
fc5295951a | ||
|
|
96520c4ff1 | ||
|
|
d3cbd5a60b | ||
|
|
906d79a432 | ||
|
|
9522e68a5b | ||
|
|
6a9bde6964 | ||
|
|
e6d449933d | ||
|
|
7cbbf271f3 | ||
|
|
202b14f6a4 | ||
|
|
0d59b22732 | ||
|
|
d7cb12470b | ||
|
|
f06ea7a901 | ||
|
|
25bc9e334c | ||
|
|
24acab0bcc | ||
|
|
0850c8cdc9 | ||
|
|
3ecf89d044 | ||
|
|
a3926d77d7 | ||
|
|
f82cecc298 | ||
|
|
382aad0a6c |
@@ -460,6 +460,8 @@
|
||||
title: AutoencoderKLQwenImage
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/autoencoder_rae
|
||||
title: AutoencoderRAE
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
|
||||
89
docs/source/en/api/models/autoencoder_rae.md
Normal file
89
docs/source/en/api/models/autoencoder_rae.md
Normal file
@@ -0,0 +1,89 @@
|
||||
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoencoderRAE
|
||||
|
||||
The Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.
|
||||
|
||||
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
|
||||
|
||||
The following RAE models are released and supported in Diffusers:
|
||||
|
||||
| Model | Encoder | Latent shape (224px input) |
|
||||
|:------|:--------|:---------------------------|
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
|
||||
|
||||
## Loading a pretrained model
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
```
|
||||
|
||||
## Encoding and decoding a real image
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.utils import load_image
|
||||
from torchvision.transforms.functional import to_tensor, to_pil_image
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
|
||||
image = image.convert("RGB").resize((224, 224))
|
||||
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
|
||||
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # (1, 768, 16, 16)
|
||||
recon = model.decode(latents).sample # (1, 3, 256, 256)
|
||||
|
||||
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
|
||||
recon_image.save("recon.png")
|
||||
```
|
||||
|
||||
## Latent normalization
|
||||
|
||||
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
|
||||
|
||||
```python
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
# Latent normalization is handled automatically inside encode/decode
|
||||
# when the checkpoint config includes latents_mean/latents_std.
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # normalized latents
|
||||
recon = model.decode(latents).sample
|
||||
```
|
||||
|
||||
## AutoencoderRAE
|
||||
|
||||
[[autodoc]] AutoencoderRAE
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
66
examples/research_projects/autoencoder_rae/README.md
Normal file
66
examples/research_projects/autoencoder_rae/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Training AutoencoderRAE
|
||||
|
||||
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
|
||||
|
||||
It follows the same high-level training recipe as the official RAE stage-1 setup:
|
||||
- frozen encoder
|
||||
- train decoder
|
||||
- pixel reconstruction loss
|
||||
- optional encoder feature consistency loss
|
||||
|
||||
## Quickstart
|
||||
|
||||
### Resume or finetune from pretrained weights
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
### Train from scratch with a pretrained encoder
|
||||
The following command launches RAE training from scratch, using `facebook/dinov2-with-registers-base` as the frozen base encoder.
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--encoder_type dinov2 \
|
||||
--encoder_name_or_path facebook/dinov2-with-registers-base \
|
||||
--encoder_input_size 224 \
|
||||
--patch_size 16 \
|
||||
--image_size 256 \
|
||||
--decoder_hidden_size 1152 \
|
||||
--decoder_num_hidden_layers 28 \
|
||||
--decoder_num_attention_heads 16 \
|
||||
--decoder_intermediate_size 4096 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
|
||||
|
||||
Dataset format is expected to be `ImageFolder`-compatible:
|
||||
|
||||
```text
|
||||
train_data_dir/
|
||||
class_a/
|
||||
img_0001.jpg
|
||||
class_b/
|
||||
img_0002.jpg
|
||||
```
|
||||
@@ -0,0 +1,405 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for stage-1 RAE decoder training.

    Returns the parsed ``argparse.Namespace``. Only ``--train_data_dir`` is
    required; everything else has a default matching the README quickstart.
    """
    parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")

    # --- Data / output locations ---
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=True,
        help="Path to an ImageFolder-style dataset root.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
    )
    parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
    parser.add_argument("--seed", type=int, default=42)

    # --- Image preprocessing (see build_transforms) ---
    parser.add_argument("--resolution", type=int, default=256)
    parser.add_argument("--center_crop", action="store_true")
    parser.add_argument("--random_flip", action="store_true")

    # --- Training schedule ---
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--num_train_epochs", type=int, default=10)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)

    # --- Optimizer (AdamW) and LR schedule ---
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)

    # --- Checkpointing / lightweight validation cadence (in optimizer steps) ---
    parser.add_argument("--checkpointing_steps", type=int, default=1000)
    parser.add_argument("--validation_steps", type=int, default=500)

    # --- Model sources: either a full pretrained AutoencoderRAE, or a fresh
    # model optionally seeded with pretrained encoder weights. ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
    )
    parser.add_argument(
        "--encoder_name_or_path",
        type=str,
        default=None,
        help=(
            "HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
            "When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
            "into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
        ),
    )

    # --- Architecture hyperparameters for a from-scratch AutoencoderRAE ---
    parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
    parser.add_argument("--encoder_hidden_size", type=int, default=768)
    parser.add_argument("--encoder_patch_size", type=int, default=14)
    parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
    parser.add_argument("--encoder_input_size", type=int, default=224)
    parser.add_argument("--patch_size", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--num_channels", type=int, default=3)

    parser.add_argument("--decoder_hidden_size", type=int, default=1152)
    parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
    parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
    parser.add_argument("--decoder_intermediate_size", type=int, default=4096)

    parser.add_argument("--noise_tau", type=float, default=0.0)
    parser.add_argument("--scaling_factor", type=float, default=1.0)
    parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)

    # --- Loss configuration. NOTE: the encoder consistency term is applied only
    # when BOTH --use_encoder_loss is passed AND --encoder_loss_weight > 0
    # (default weight is 0.0, so --use_encoder_loss alone has no effect). ---
    parser.add_argument(
        "--reconstruction_loss_type",
        type=str,
        choices=["l1", "mse"],
        default="l1",
        help="Pixel reconstruction loss.",
    )
    parser.add_argument(
        "--encoder_loss_weight",
        type=float,
        default=0.0,
        help="Weight for encoder feature consistency loss in the training loop.",
    )
    parser.add_argument(
        "--use_encoder_loss",
        action="store_true",
        help="Enable encoder feature consistency loss term in the training loop.",
    )
    parser.add_argument("--report_to", type=str, default="tensorboard")

    return parser.parse_args()
|
||||
|
||||
|
||||
def build_transforms(args):
    """Build the torchvision preprocessing pipeline for training images.

    Resizes with bicubic interpolation, crops to ``args.resolution``
    (center crop when ``args.center_crop`` is set, random crop otherwise),
    optionally applies a random horizontal flip, and converts to a tensor
    in ``[0, 1]``.
    """
    crop_op = (
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
    )
    pipeline = [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
        crop_op,
    ]
    if args.random_flip:
        pipeline.append(transforms.RandomHorizontalFlip())
    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)
|
||||
|
||||
|
||||
def compute_losses(
    model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
):
    """Run a forward pass and compute the stage-1 RAE training losses.

    Args:
        model: The (possibly DDP-wrapped) ``AutoencoderRAE``; must return an
            object with a ``.sample`` attribute when called on pixel values.
        pixel_values: Target images, shape ``(B, C, H, W)``.
        reconstruction_loss_type: ``"l1"`` or ``"mse"`` pixel loss.
        use_encoder_loss: Whether to add the encoder feature-consistency term.
        encoder_loss_weight: Weight of the encoder term (skipped when <= 0).

    Returns:
        Tuple of ``(decoded, total_loss, reconstruction_loss, encoder_loss)``.

    Raises:
        ValueError: If the decoded output and target spatial sizes differ.
    """
    decoded = model(pixel_values).sample

    # Pixel losses are elementwise; a size mismatch would silently broadcast or crash later.
    if decoded.shape[-2:] != pixel_values.shape[-2:]:
        raise ValueError(
            "Training requires matching reconstruction and target sizes, got "
            f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
        )

    if reconstruction_loss_type == "l1":
        reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
    else:
        reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())

    encoder_loss = torch.zeros_like(reconstruction_loss)
    if use_encoder_loss and encoder_loss_weight > 0:
        # Unwrap DDP so we can reach the model's private encoder helpers.
        base_model = model.module if hasattr(model, "module") else model
        target_encoder_input = base_model._resize_and_normalize(pixel_values)
        reconstructed_encoder_input = base_model._resize_and_normalize(decoded)

        encoder_forward_kwargs = {"model": base_model.encoder}
        if base_model.config.encoder_type == "mae":
            encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size

        # Target tokens come from the real image and must not carry gradient.
        with torch.no_grad():
            target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
        # BUGFIX: the reconstructed tokens must be computed WITH gradient enabled.
        # Previously this forward also ran under `torch.no_grad()`, which detached
        # the encoder loss from `decoded` and made the consistency term a no-op
        # for training (zero gradient to the decoder).
        reconstructed_tokens = base_model._encoder_forward_fn(
            images=reconstructed_encoder_input, **encoder_forward_kwargs
        )
        encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())

    loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
    return decoded, loss, reconstruction_loss, encoder_loss
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict, prefix=""):
|
||||
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
    """Copy pretrained HF transformers encoder weights into ``model.encoder``.

    Supports the three encoder families handled by ``AutoencoderRAE``
    (``dinov2``, ``siglip2``, ``mae``). The final layernorm's affine
    parameters are stripped so the encoder keeps its identity-initialized
    (non-affine) layernorm.
    """
    if encoder_type == "dinov2":
        from transformers import Dinov2WithRegistersModel

        source = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
        weights = _strip_final_layernorm_affine(source.state_dict(), prefix="layernorm.")
    elif encoder_type == "siglip2":
        from transformers import SiglipModel

        source = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
        # The RAE encoder nests the SigLIP tower under a `vision_model.` prefix.
        weights = {f"vision_model.{name}": tensor for name, tensor in source.state_dict().items()}
        weights = _strip_final_layernorm_affine(weights, prefix="vision_model.post_layernorm.")
    elif encoder_type == "mae":
        from transformers import ViTMAEForPreTraining

        source = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
        weights = _strip_final_layernorm_affine(source.state_dict(), prefix="layernorm.")
    else:
        raise ValueError(f"Unknown encoder_type: {encoder_type}")

    # strict=False: the stripped layernorm keys are intentionally absent.
    model.encoder.load_state_dict(weights, strict=False)
|
||||
|
||||
|
||||
def main():
    """Entry point: build data, model, and optimizer, run stage-1 RAE decoder training, save the result."""
    args = parse_args()
    # Stage-1 reconstruction loss compares output and target pixel-by-pixel,
    # so the crop resolution must equal the model's output image size.
    if args.resolution != args.image_size:
        raise ValueError(
            f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
            "for stage-1 reconstruction loss."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        project_config=accelerator_project_config,
        log_with=args.report_to,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # main_process_only=False: every rank logs its accelerator state once.
    logger.info(accelerator.state, main_process_only=False)

    if args.seed is not None:
        set_seed(args.seed)

    # Only rank 0 creates the output dir; the barrier keeps other ranks from racing past it.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))

    def collate_fn(examples):
        # ImageFolder yields (image, label) pairs; labels are unused for reconstruction training.
        pixel_values = torch.stack([example[0] for example in examples]).float()
        return {"pixel_values": pixel_values}

    train_dataloader = DataLoader(
        dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Either resume a full AutoencoderRAE, or construct one from CLI hyperparameters
    # (optionally seeding the encoder from a pretrained HF checkpoint).
    if args.pretrained_model_name_or_path is not None:
        model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
        logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
    else:
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=args.encoder_hidden_size,
            encoder_patch_size=args.encoder_patch_size,
            encoder_num_hidden_layers=args.encoder_num_hidden_layers,
            decoder_hidden_size=args.decoder_hidden_size,
            decoder_num_hidden_layers=args.decoder_num_hidden_layers,
            decoder_num_attention_heads=args.decoder_num_attention_heads,
            decoder_intermediate_size=args.decoder_intermediate_size,
            patch_size=args.patch_size,
            encoder_input_size=args.encoder_input_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            noise_tau=args.noise_tau,
            reshape_to_2d=args.reshape_to_2d,
            use_encoder_loss=args.use_encoder_loss,
            scaling_factor=args.scaling_factor,
        )
        if args.encoder_name_or_path is not None:
            _load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
            logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")

    # Stage-1 recipe: frozen representation encoder, trainable decoder only.
    model.encoder.requires_grad_(False)
    model.decoder.requires_grad_(True)
    model.train()

    # Optimizer sees only the trainable (decoder) parameters.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Derive max_train_steps from epochs when not given; recomputed after
    # accelerator.prepare() because the dataloader length can change under sharding.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    if overrode_max_train_steps:
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        accelerator.init_trackers("train_autoencoder_rae", config=vars(args))

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0  # counts optimizer (sync) steps, not micro-batches

    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                pixel_values = batch["pixel_values"]

                _, loss, reconstruction_loss, encoder_loss = compute_losses(
                    model,
                    pixel_values,
                    reconstruction_loss_type=args.reconstruction_loss_type,
                    use_encoder_loss=args.use_encoder_loss,
                    encoder_loss_weight=args.encoder_loss_weight,
                )

                accelerator.backward(loss)
                # Clip only on sync steps, when the accumulated gradients are complete.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                logs = {
                    "loss": loss.detach().item(),
                    "reconstruction_loss": reconstruction_loss.detach().item(),
                    "encoder_loss": encoder_loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                # NOTE(review): this "validation" re-evaluates the CURRENT training
                # batch under no_grad — it is a cheap sanity signal, not a held-out
                # evaluation; confirm whether a real validation split is intended.
                if global_step % args.validation_steps == 0:
                    with torch.no_grad():
                        _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
                            model,
                            pixel_values,
                            reconstruction_loss_type=args.reconstruction_loss_type,
                            use_encoder_loss=args.use_encoder_loss,
                            encoder_loss_weight=args.encoder_loss_weight,
                        )
                    accelerator.log(
                        {
                            "val/loss": val_loss.detach().item(),
                            "val/reconstruction_loss": val_reconstruction_loss.detach().item(),
                            "val/encoder_loss": val_encoder_loss.detach().item(),
                        },
                        step=global_step,
                    )

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(save_path)
                        logger.info(f"Saved checkpoint to {save_path}")

            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            break

    # Final save of the unwrapped model on rank 0.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir)
    accelerator.end_training()
|
||||
|
||||
|
||||
# Script entry point (typically invoked via `accelerate launch`).
if __name__ == "__main__":
    main()
|
||||
403
scripts/convert_rae_to_diffusers.py
Normal file
403
scripts/convert_rae_to_diffusers.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
|
||||
# ViT decoder size presets keyed by the variant tag used in RAE checkpoint names.
DECODER_CONFIGS = {
    "ViTB": {
        "decoder_hidden_size": 768,
        "decoder_intermediate_size": 3072,
        "decoder_num_attention_heads": 12,
        "decoder_num_hidden_layers": 12,
    },
    "ViTL": {
        "decoder_hidden_size": 1024,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 24,
    },
    "ViTXL": {
        "decoder_hidden_size": 1152,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 28,
    },
}

# Default HF Hub ids for each supported frozen-encoder family.
ENCODER_DEFAULT_NAME_OR_PATH = {
    "dinov2": "facebook/dinov2-with-registers-base",
    "siglip2": "google/siglip2-base-patch16-256",
    "mae": "facebook/vit-mae-base",
}

# Encoder hidden size per family (all defaults here are base-sized, hence 768).
ENCODER_HIDDEN_SIZE = {
    "dinov2": 768,
    "siglip2": 768,
    "mae": 768,
}

# ViT patch size used by each encoder family.
ENCODER_PATCH_SIZE = {
    "dinov2": 14,
    "siglip2": 16,
    "mae": 16,
}

# Subdirectories inside the official RAE release repo holding decoder checkpoints.
DEFAULT_DECODER_SUBDIR = {
    "dinov2": "decoders/dinov2/wReg_base",
    "mae": "decoders/mae/base_p16",
    "siglip2": "decoders/siglip2/base_p16_i256",
}

# Subdirectories inside the official RAE release repo holding latent statistics.
DEFAULT_STATS_SUBDIR = {
    "dinov2": "stats/dinov2/wReg_base",
    "mae": "stats/mae/base_p16",
    "siglip2": "stats/siglip2/base_p16_i256",
}

# Filenames probed, in order, when no explicit checkpoint path is supplied.
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
STATS_FILE_CANDIDATES = ("stat.pt",)
|
||||
|
||||
|
||||
def dataset_case_candidates(name: str) -> tuple[str, ...]:
    """Return candidate spellings for a dataset directory name.

    Yields the given name plus common casings and the ``imagenet1k``
    fallbacks used by the official RAE repo layout.

    Improvement: the raw candidate list can contain duplicates (e.g. when
    ``name`` is already lowercase, or is itself ``"imagenet1k"``), which
    caused redundant repo-existence probes downstream. Deduplicate while
    preserving first-seen order so lookup behavior is unchanged.
    """
    candidates = (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
    return tuple(dict.fromkeys(candidates))
|
||||
|
||||
|
||||
class RepoAccessor:
|
||||
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
|
||||
self.repo_or_path = repo_or_path
|
||||
self.cache_dir = cache_dir
|
||||
self.local_root: Path | None = None
|
||||
self.repo_id: str | None = None
|
||||
self.repo_files: set[str] | None = None
|
||||
|
||||
root = Path(repo_or_path)
|
||||
if root.exists() and root.is_dir():
|
||||
self.local_root = root
|
||||
else:
|
||||
self.repo_id = repo_or_path
|
||||
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
|
||||
|
||||
def exists(self, relative_path: str) -> bool:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return (self.local_root / relative_path).is_file()
|
||||
return relative_path in self.repo_files
|
||||
|
||||
def fetch(self, relative_path: str) -> Path:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return self.local_root / relative_path
|
||||
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
|
||||
return Path(downloaded)
|
||||
|
||||
|
||||
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
    """Peel common checkpoint wrappers and uniform key prefixes.

    Descends through ``model`` / ``module`` / ``state_dict`` wrapper keys
    (in that order), then strips a leading ``module.`` and/or ``decoder.``
    prefix when every key carries it.
    """
    unwrapped = maybe_wrapped
    for wrapper_key in ("model", "module", "state_dict"):
        if isinstance(unwrapped, dict) and isinstance(unwrapped.get(wrapper_key), dict):
            unwrapped = unwrapped[wrapper_key]

    result = dict(unwrapped)
    for uniform_prefix in ("module.", "decoder."):
        if result and all(key.startswith(uniform_prefix) for key in result):
            result = {key[len(uniform_prefix) :]: value for key, value in result.items()}
    return result
|
||||
|
||||
|
||||
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder.

    Example mappings:
    - `...attention.attention.query.*` -> `...attention.to_q.*`
    - `...attention.attention.key.*` -> `...attention.to_k.*`
    - `...attention.attention.value.*` -> `...attention.to_v.*`
    - `...attention.output.dense.*` -> `...attention.to_out.0.*`
    """
    substitutions = (
        (".attention.attention.query.", ".attention.to_q."),
        (".attention.attention.key.", ".attention.to_k."),
        (".attention.attention.value.", ".attention.to_v."),
        (".attention.output.dense.", ".attention.to_out.0."),
    )
    converted: dict[str, Any] = {}
    for original_key, tensor in state_dict.items():
        converted_key = original_key
        for old_fragment, new_fragment in substitutions:
            converted_key = converted_key.replace(old_fragment, new_fragment)
        converted[converted_key] = tensor
    return converted
|
||||
|
||||
|
||||
def resolve_decoder_file(
    accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
    """Return the repo-relative path of the decoder checkpoint.

    An explicit ``decoder_checkpoint`` is validated and returned as-is;
    otherwise the default layout for ``encoder_type``/``variant`` is probed
    using the known candidate filenames.

    Raises:
        FileNotFoundError: When no checkpoint can be located.
    """
    if decoder_checkpoint is not None:
        if not accessor.exists(decoder_checkpoint):
            raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
        return decoder_checkpoint

    base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
    for filename in DECODER_FILE_CANDIDATES:
        relative = f"{base}/{filename}"
        if accessor.exists(relative):
            return relative

    raise FileNotFoundError(
        f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
    )
|
||||
|
||||
|
||||
def resolve_stats_file(
    accessor: RepoAccessor,
    encoder_type: str,
    dataset_name: str,
    stats_checkpoint: str | None,
) -> str | None:
    """Return the repo-relative path of the latent-stats checkpoint, or `None`.

    An explicitly provided `stats_checkpoint` must exist; otherwise dataset-name
    casings and known file names are probed under the default stats subdirectory.
    Unlike the decoder file, stats are optional, so a miss returns `None` instead
    of raising.
    """
    if stats_checkpoint is not None:
        if not accessor.exists(stats_checkpoint):
            raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
        return stats_checkpoint

    base = DEFAULT_STATS_SUBDIR[encoder_type]
    for dataset in dataset_case_candidates(dataset_name):
        for file_name in STATS_FILE_CANDIDATES:
            relpath = f"{base}/{dataset}/{file_name}"
            if accessor.exists(relpath):
                return relpath

    # Stats are optional; the caller decides how to handle the missing case.
    return None
|
||||
|
||||
|
||||
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
|
||||
if not isinstance(stats_obj, dict):
|
||||
return None, None
|
||||
|
||||
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
|
||||
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
|
||||
|
||||
mean = stats_obj.get("mean", None)
|
||||
var = stats_obj.get("var", None)
|
||||
if mean is None and var is None:
|
||||
return None, None
|
||||
|
||||
latents_std = None
|
||||
if var is not None:
|
||||
if isinstance(var, torch.Tensor):
|
||||
latents_std = torch.sqrt(var + 1e-5)
|
||||
else:
|
||||
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
|
||||
return mean, latents_std
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
|
||||
"""Remove final layernorm weight/bias from encoder state dict.
|
||||
|
||||
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
|
||||
Stripping these keys means the model keeps its default init values, which
|
||||
is functionally equivalent to setting elementwise_affine=False.
|
||||
"""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
    """Download the HF encoder and extract the state dict for the inner model."""
    if encoder_type == "dinov2":
        from transformers import Dinov2WithRegistersModel

        state_dict = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path).state_dict()
        return _strip_final_layernorm_affine(state_dict, prefix="layernorm.")

    if encoder_type == "siglip2":
        from transformers import SiglipModel

        # SiglipModel.vision_model is a SiglipVisionTransformer.
        # Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
        # under .vision_model, so we add the prefix to match the diffusers key layout.
        vision_tower = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
        state_dict = {f"vision_model.{name}": tensor for name, tensor in vision_tower.state_dict().items()}
        return _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")

    if encoder_type == "mae":
        from transformers import ViTMAEForPreTraining

        # Only the ViT backbone is needed; the pretraining decoder head is discarded.
        state_dict = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit.state_dict()
        return _strip_final_layernorm_affine(state_dict, prefix="layernorm.")

    raise ValueError(f"Unknown encoder_type: {encoder_type}")
|
||||
|
||||
|
||||
def convert(args: argparse.Namespace) -> None:
    """Convert an official RAE checkpoint into a diffusers `AutoencoderRAE` directory.

    Steps: resolve decoder/stats files in the source repo, load and remap the
    decoder weights, download the matching HF encoder weights, assemble a full
    state dict, build the model on the meta device, load with `assign=True`,
    and `save_pretrained` the result. With `--dry_run`, only file resolution is
    printed and nothing is downloaded or written.
    """
    accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
    encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]

    # Resolve source files first so --dry_run can report them without downloads.
    decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
    stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)

    print(f"Using decoder checkpoint: {decoder_relpath}")
    if stats_relpath is not None:
        print(f"Using stats checkpoint: {stats_relpath}")
    else:
        print("No stats checkpoint found; conversion will proceed without latent stats.")

    if args.dry_run:
        return

    # Fetch and normalize the decoder weights into the diffusers key layout.
    decoder_path = accessor.fetch(decoder_relpath)
    decoder_obj = torch.load(decoder_path, map_location="cpu")
    decoder_state_dict = unwrap_state_dict(decoder_obj)
    decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)

    latents_mean, latents_std = None, None
    if stats_relpath is not None:
        stats_path = accessor.fetch(stats_relpath)
        stats_obj = torch.load(stats_path, map_location="cpu")
        latents_mean, latents_std = extract_latent_stats(stats_obj)

    decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]

    # Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
    from transformers import AutoConfig, AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
    encoder_norm_mean = list(proc.image_mean)
    encoder_norm_std = list(proc.image_std)

    # Read encoder hidden size and patch size from HF config
    encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
    encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
    try:
        hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
        # For models like SigLIP that nest vision config
        if hasattr(hf_config, "vision_config"):
            hf_config = hf_config.vision_config
        encoder_hidden_size = hf_config.hidden_size
        encoder_patch_size = hf_config.patch_size
    except Exception:
        # Best-effort: fall back to the hard-coded per-encoder-type defaults above.
        pass

    # Load the actual encoder weights from HF to include in the saved model
    encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)

    # Build model on meta device to avoid double init overhead
    with torch.device("meta"):
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=encoder_hidden_size,
            encoder_patch_size=encoder_patch_size,
            encoder_input_size=args.encoder_input_size,
            patch_size=args.patch_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            encoder_norm_mean=encoder_norm_mean,
            encoder_norm_std=encoder_norm_std,
            decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
            decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
            decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
            decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
            latents_mean=latents_mean,
            latents_std=latents_std,
            scaling_factor=args.scaling_factor,
        )

    # Assemble full state dict and load with assign=True
    full_state_dict = {}

    # Encoder weights (prefixed with "encoder.")
    for k, v in encoder_state_dict.items():
        full_state_dict[f"encoder.{k}"] = v

    # Decoder weights (prefixed with "decoder.")
    for k, v in decoder_state_dict.items():
        full_state_dict[f"decoder.{k}"] = v

    # Buffers from config
    full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
    if latents_mean is not None:
        latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
        full_state_dict["_latents_mean"] = latents_mean_t
    else:
        # No stats: identity normalization (mean 0).
        full_state_dict["_latents_mean"] = torch.zeros(1)
    if latents_std is not None:
        latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
        full_state_dict["_latents_std"] = latents_std_t
    else:
        # No stats: identity normalization (std 1).
        full_state_dict["_latents_std"] = torch.ones(1)

    # assign=True moves tensors off the meta device; strict=False because some
    # parameters are intentionally initialized rather than loaded (see below).
    model.load_state_dict(full_state_dict, strict=False, assign=True)

    # Verify no critical keys are missing
    model_keys = {name for name, _ in model.named_parameters()}
    model_keys |= {name for name, _ in model.named_buffers()}
    loaded_keys = set(full_state_dict.keys())
    missing = model_keys - loaded_keys
    # trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
    allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
    if missing - allowed_missing:
        print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")

    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)

    if args.verify_load:
        print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
        loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
        if not isinstance(loaded_model, AutoencoderRAE):
            raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
        print("Verification passed.")

    print(f"Saved converted AutoencoderRAE to: {output_path}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line arguments for the conversion script."""
    p = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")

    # Source repo and destination directory.
    p.add_argument(
        "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
    )
    p.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")

    # Encoder selection.
    p.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
    p.add_argument(
        "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
    )

    # Checkpoint discovery inside the repo.
    p.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
    p.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
    p.add_argument(
        "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
    )
    p.add_argument(
        "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
    )

    # Model hyperparameters.
    p.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
    p.add_argument("--encoder_input_size", type=int, default=224)
    p.add_argument("--patch_size", type=int, default=16)
    p.add_argument("--image_size", type=int, default=None)
    p.add_argument("--num_channels", type=int, default=3)
    p.add_argument("--scaling_factor", type=float, default=1.0)

    # Execution options.
    p.add_argument("--cache_dir", type=str, default=None)
    p.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
    p.add_argument(
        "--verify_load",
        action="store_true",
        help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
    )

    return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments and run the conversion."""
    convert(parse_args())
|
||||
|
||||
|
||||
# Script entry point (no side effects on import).
if __name__ == "__main__":
    main()
|
||||
@@ -202,6 +202,7 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderRAE",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
@@ -974,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
|
||||
@@ -49,6 +49,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
@@ -168,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
|
||||
@@ -18,6 +18,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_rae import AutoencoderRAE
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .vq_model import VQModel
|
||||
|
||||
692
src/diffusers/models/autoencoders/autoencoder_rae.py
Normal file
692
src/diffusers/models/autoencoders/autoencoder_rae.py
Normal file
@@ -0,0 +1,692 @@
|
||||
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import sqrt
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.import_utils import is_transformers_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import (
|
||||
Dinov2WithRegistersConfig,
|
||||
Dinov2WithRegistersModel,
|
||||
SiglipVisionConfig,
|
||||
SiglipVisionModel,
|
||||
ViTMAEConfig,
|
||||
ViTMAEModel,
|
||||
)
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_2d_sincos_pos_embed
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-encoder forward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
# Each function takes the raw transformers model + images and returns patch
|
||||
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
|
||||
|
||||
|
||||
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True)
|
||||
unused_token_num = 5 # 1 CLS + 4 register tokens
|
||||
return outputs.last_hidden_state[:, unused_token_num:]
|
||||
|
||||
|
||||
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
|
||||
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
h, w = images.shape[2], images.shape[3]
|
||||
patch_num = int(h * w // patch_size**2)
|
||||
if patch_num * patch_size**2 != h * w:
|
||||
raise ValueError("Image size should be divisible by patch size.")
|
||||
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
|
||||
outputs = model(images, noise, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state[:, 1:] # remove cls token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder construction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_encoder(
    encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    # All supported encoders use head_dim=64.
    num_attention_heads = hidden_size // head_dim
    shared_kwargs = {
        "hidden_size": hidden_size,
        "patch_size": patch_size,
        "num_attention_heads": num_attention_heads,
        "num_hidden_layers": num_hidden_layers,
    }

    if encoder_type == "dinov2":
        model = Dinov2WithRegistersModel(Dinov2WithRegistersConfig(image_size=518, **shared_kwargs))
        final_layernorm = model.layernorm
    elif encoder_type == "siglip2":
        model = SiglipVisionModel(SiglipVisionConfig(image_size=256, **shared_kwargs))
        final_layernorm = model.vision_model.post_layernorm
    elif encoder_type == "mae":
        model = ViTMAEModel(ViTMAEConfig(image_size=224, mask_ratio=0.0, **shared_kwargs))
        final_layernorm = model.layernorm
    else:
        raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")

    # RAE strips the final layernorm affine params (identity LN). Remove them from
    # the architecture so `from_pretrained` doesn't leave them on the meta device.
    final_layernorm.weight = None
    final_layernorm.bias = None

    model.requires_grad_(False)
    return model
|
||||
|
||||
|
||||
# Dispatch table: encoder_type -> forward function that returns patch tokens (B, N, C).
_ENCODER_FORWARD_FNS = {
    "dinov2": _dinov2_encoder_forward,
    "siglip2": _siglip2_encoder_forward,
    "mae": _mae_encoder_forward,
}
|
||||
|
||||
|
||||
@dataclass
class RAEDecoderOutput(BaseOutput):
    """
    Output of `RAEDecoder`.

    Args:
        logits (`torch.Tensor`):
            Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
    """

    # Per-patch pixel predictions; callers turn these into images via `RAEDecoder.unpatchify`.
    logits: torch.Tensor
|
||||
|
||||
|
||||
class ViTMAEIntermediate(nn.Module):
    """Feed-forward expansion of a transformer block: linear projection followed by activation."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
        super().__init__()
        # Attribute names match the RAE-main checkpoint key layout.
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = get_activation(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||||
|
||||
|
||||
class ViTMAEOutput(nn.Module):
    """Feed-forward projection back to hidden size, with dropout and a residual add."""

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
        super().__init__()
        # Attribute names match the RAE-main checkpoint key layout.
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
|
||||
|
||||
|
||||
class ViTMAELayer(nn.Module):
    """
    This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).

    A pre-norm transformer block: self-attention with residual, then an MLP whose
    residual add lives inside `ViTMAEOutput`.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
            )
        head_dim = hidden_size // num_attention_heads
        self.attention = Attention(
            query_dim=hidden_size,
            heads=num_attention_heads,
            dim_head=head_dim,
            dropout=attention_probs_dropout_prob,
            bias=qkv_bias,
        )
        self.intermediate = ViTMAEIntermediate(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
        )
        self.output = ViTMAEOutput(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
        )
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block: pre-norm, then residual connection.
        residual = hidden_states
        hidden_states = residual + self.attention(self.layernorm_before(residual))

        # MLP sub-block: pre-norm; ViTMAEOutput adds the residual internally.
        normed = self.layernorm_after(hidden_states)
        return self.output(self.intermediate(normed), hidden_states)
|
||||
|
||||
|
||||
class RAEDecoder(nn.Module):
    """Lightweight RAE decoder.

    A small ViT that maps encoder patch tokens `(B, N, hidden_size)` to patch
    pixel logits `(B, N, patch_size**2 * num_channels)`. A trainable CLS token
    is prepended before the transformer blocks and stripped from the prediction.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        decoder_hidden_size: int = 512,
        decoder_num_hidden_layers: int = 8,
        decoder_num_attention_heads: int = 16,
        decoder_intermediate_size: int = 2048,
        num_patches: int = 256,
        patch_size: int = 16,
        num_channels: int = 3,
        image_size: int = 256,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        self.decoder_hidden_size = decoder_hidden_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.num_patches = num_patches

        # Projects encoder tokens into the decoder width.
        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
        # Fixed 2D sin-cos positional embedding (+1 slot for the CLS token);
        # filled in `_initialize_weights`, not loaded from checkpoints.
        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))

        self.decoder_layers = nn.ModuleList(
            [
                ViTMAELayer(
                    hidden_size=decoder_hidden_size,
                    num_attention_heads=decoder_num_attention_heads,
                    intermediate_size=decoder_intermediate_size,
                    qkv_bias=qkv_bias,
                    layer_norm_eps=layer_norm_eps,
                    hidden_dropout_prob=hidden_dropout_prob,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    hidden_act=hidden_act,
                )
                for _ in range(decoder_num_hidden_layers)
            ]
        )

        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
        # Head predicting flattened patch pixels.
        self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
        self.gradient_checkpointing = False

        self._initialize_weights(num_patches)
        self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

    def _initialize_weights(self, num_patches: int):
        """Fill the positional-embedding buffer with a 2D sin-cos embedding."""
        # Skip initialization when parameters are on meta device (e.g. during
        # accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
        # The weights are initialized.
        if self.decoder_pos_embed.device.type == "meta":
            return

        grid_size = int(num_patches**0.5)
        pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            grid_size,
            cls_token=True,
            extra_tokens=1,
            output_type="pt",
            device=self.decoder_pos_embed.device,
        )
        self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Resize the positional embedding to match the token count of `embeddings`.

        The CLS slot is kept as-is; only the patch positions are interpolated.
        NOTE(review): the patch embedding is interpolated as a flat 1D sequence
        (shape (1, dim, 1, N)) via bicubic `scale_factor`, not as a 2D grid —
        confirm this matches the upstream RAE behavior.
        """
        embeddings_positions = embeddings.shape[1] - 1
        num_positions = self.decoder_pos_embed.shape[1] - 1

        class_pos_embed = self.decoder_pos_embed[:, 0, :]
        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
        dim = self.decoder_pos_embed.shape[-1]

        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            scale_factor=(1, embeddings_positions / num_positions),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Bilinearly resample (B, L, C) tokens to (B, num_patches, C).

        Assumes the L tokens form a square grid (L is a perfect square).
        No-op when the token count already matches `num_patches`.
        """
        b, l, c = x.shape
        if l == self.num_patches:
            return x
        h = w = int(l**0.5)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
        return x

    def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
        """Reassemble patch logits (B, N, p*p*C) into images (B, C, H, W).

        Args:
            patchified_pixel_values: flattened per-patch pixels, N = (H/p) * (W/p).
            original_image_size: target (height, width); defaults to the
                configured square `image_size`.

        Raises:
            ValueError: if the patch count does not match `original_image_size`.
        """
        patch_size, num_channels = self.patch_size, self.num_channels
        original_image_size = (
            original_image_size if original_image_size is not None else (self.image_size, self.image_size)
        )
        original_height, original_width = original_image_size
        num_patches_h = original_height // patch_size
        num_patches_w = original_width // patch_size
        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
            raise ValueError(
                f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
            )

        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_h,
            num_patches_w,
            patch_size,
            patch_size,
            num_channels,
        )
        # (n, h, w, p, q, c) -> (n, c, h, p, w, q): interleave patch rows/cols back into the image grid.
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_h * patch_size,
            num_patches_w * patch_size,
        )
        return pixel_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        drop_cls_token: bool = False,
        return_dict: bool = True,
    ) -> RAEDecoderOutput | tuple[torch.Tensor]:
        """Decode encoder tokens into patch pixel logits.

        Args:
            hidden_states: encoder tokens of shape (B, L, hidden_size); if
                `drop_cls_token` is True the first token is removed first.
            interpolate_pos_encoding: resize the positional embedding to the
                actual token count (requires `drop_cls_token=True`).
            drop_cls_token: whether `hidden_states` carries a leading CLS token.
            return_dict: return `RAEDecoderOutput` instead of a tuple.
        """
        x = self.decoder_embed(hidden_states)
        if drop_cls_token:
            x_ = x[:, 1:, :]
            x_ = self.interpolate_latent(x_)
        else:
            x_ = self.interpolate_latent(x)

        # Always prepend the decoder's own trainable CLS token.
        cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
        x = torch.cat([cls_token, x_], dim=1)

        if interpolate_pos_encoding:
            if not drop_cls_token:
                raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
            decoder_pos_embed = self.interpolate_pos_encoding(x)
        else:
            decoder_pos_embed = self.decoder_pos_embed

        hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)

        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states)

        hidden_states = self.decoder_norm(hidden_states)
        logits = self.decoder_pred(hidden_states)
        # Strip the CLS position; only patch predictions are returned.
        logits = logits[:, 1:, :]

        if not return_dict:
            return (logits,)
        return RAEDecoderOutput(logits=logits)
|
||||
|
||||
|
||||
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
|
||||
|
||||
This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
|
||||
images from learned representations.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Args:
|
||||
encoder_type (`str`, *optional*, defaults to `"dinov2"`):
|
||||
Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
|
||||
encoder_hidden_size (`int`, *optional*, defaults to `768`):
|
||||
Hidden size of the encoder model.
|
||||
encoder_patch_size (`int`, *optional*, defaults to `14`):
|
||||
Patch size of the encoder model.
|
||||
encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
|
||||
Number of hidden layers in the encoder model.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
Decoder patch size (used for unpatchify and decoder head).
|
||||
encoder_input_size (`int`, *optional*, defaults to `224`):
|
||||
Input size expected by the encoder.
|
||||
image_size (`int`, *optional*):
|
||||
Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
|
||||
RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
|
||||
encoder_patch_size) ** 2`.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of input/output channels.
|
||||
encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
Channel-wise mean for encoder input normalization (ImageNet defaults).
|
||||
encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
Channel-wise std for encoder input normalization (ImageNet defaults).
|
||||
latents_mean (`list` or `tuple`, *optional*):
|
||||
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
|
||||
lists.
|
||||
latents_std (`list` or `tuple`, *optional*):
|
||||
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
|
||||
config-serializable lists.
|
||||
noise_tau (`float`, *optional*, defaults to `0.0`):
|
||||
Noise level for training (adds noise to latents during training).
|
||||
reshape_to_2d (`bool`, *optional*, defaults to `True`):
|
||||
Whether to reshape latents to 2D (B, C, H, W) format.
|
||||
use_encoder_loss (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use encoder hidden states in the loss (for advanced training).
|
||||
"""
|
||||
|
||||
# NOTE: gradient checkpointing is not wired up for this model yet.
|
||||
_supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["ViTMAELayer"]
|
||||
|
||||
@register_to_config
def __init__(
    self,
    encoder_type: str = "dinov2",
    encoder_hidden_size: int = 768,
    encoder_patch_size: int = 14,
    encoder_num_hidden_layers: int = 12,
    decoder_hidden_size: int = 512,
    decoder_num_hidden_layers: int = 8,
    decoder_num_attention_heads: int = 16,
    decoder_intermediate_size: int = 2048,
    patch_size: int = 16,
    encoder_input_size: int = 224,
    image_size: int | None = None,
    num_channels: int = 3,
    encoder_norm_mean: list | None = None,
    encoder_norm_std: list | None = None,
    latents_mean: list | tuple | torch.Tensor | None = None,
    latents_std: list | tuple | torch.Tensor | None = None,
    noise_tau: float = 0.0,
    reshape_to_2d: bool = True,
    use_encoder_loss: bool = False,
    scaling_factor: float = 1.0,
):
    """Build the frozen representation encoder, normalization buffers, and ViT-MAE decoder.

    Raises:
        ValueError: if `encoder_type` is unknown, `encoder_input_size` is not divisible by
            `encoder_patch_size`, `patch_size` is not positive, the derived patch count is
            not a perfect square, or an explicit `image_size` disagrees with the geometry.
    """
    super().__init__()

    # --- Validate configuration ------------------------------------------------
    if encoder_type not in _ENCODER_FORWARD_FNS:
        raise ValueError(
            f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
        )

    if encoder_input_size % encoder_patch_size != 0:
        raise ValueError(
            f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
        )

    decoder_patch_size = patch_size
    if decoder_patch_size <= 0:
        raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")

    # Token-grid geometry: the encoder fixes num_patches, and the decoder's output
    # size must agree with it. Computed once here and reused for the decoder below.
    num_patches = (encoder_input_size // encoder_patch_size) ** 2
    grid = int(sqrt(num_patches))
    if grid * grid != num_patches:
        raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")

    derived_image_size = decoder_patch_size * grid
    if image_size is None:
        image_size = derived_image_size
    else:
        image_size = int(image_size)
        if image_size != derived_image_size:
            raise ValueError(
                f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
                f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
            )

    # --- Config-serialization helpers ------------------------------------------
    def _to_config_compatible(value: Any) -> Any:
        # Recursively convert tensors (and nested tuples/lists) to plain lists so
        # the registered config stays JSON-serializable.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        if isinstance(value, (tuple, list)):
            return [_to_config_compatible(v) for v in value]
        return value

    def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
        # None passes through; tensors are detached copies; lists/tuples become float32 tensors.
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value, dtype=torch.float32)

    # Materialize both latent stats together so they cannot drift apart.
    latents_mean_tensor = _as_optional_tensor(latents_mean)
    latents_std_tensor = _as_optional_tensor(latents_std)

    # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
    self.register_to_config(
        latents_mean=_to_config_compatible(latents_mean),
        latents_std=_to_config_compatible(latents_std),
    )

    # Frozen representation encoder (built from config, no downloads)
    self.encoder: nn.Module = _build_encoder(
        encoder_type=encoder_type,
        hidden_size=encoder_hidden_size,
        patch_size=encoder_patch_size,
        num_hidden_layers=encoder_num_hidden_layers,
    )
    self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]

    # Encoder input normalization stats (ImageNet defaults)
    if encoder_norm_mean is None:
        encoder_norm_mean = [0.485, 0.456, 0.406]
    if encoder_norm_std is None:
        encoder_norm_std = [0.229, 0.224, 0.225]
    encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)

    self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
    self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)

    # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
    if latents_mean_tensor is None:
        latents_mean_tensor = torch.zeros(1)
    self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)

    if latents_std_tensor is None:
        latents_std_tensor = torch.ones(1)
    self.register_buffer("_latents_std", latents_std_tensor, persistent=True)

    # ViT-MAE style decoder
    self.decoder = RAEDecoder(
        hidden_size=int(encoder_hidden_size),
        decoder_hidden_size=int(decoder_hidden_size),
        decoder_num_hidden_layers=int(decoder_num_hidden_layers),
        decoder_num_attention_heads=int(decoder_num_attention_heads),
        decoder_intermediate_size=int(decoder_intermediate_size),
        num_patches=int(num_patches),
        patch_size=int(decoder_patch_size),
        num_channels=int(num_channels),
        image_size=int(image_size),
    )

    self.num_patches = int(num_patches)
    self.decoder_patch_size = int(decoder_patch_size)
    self.decoder_image_size = int(image_size)

    # Slicing support (batch dimension) similar to other diffusers autoencoders
    self.use_slicing = False
||||
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    """Add per-sample Gaussian noise with sigma drawn uniformly from [0, noise_tau]."""
    # One sigma per batch element, broadcastable over the remaining dims.
    sigma_shape = (x.size(0),) + (1,) * (x.ndim - 1)
    sigma = torch.rand(sigma_shape, device=x.device, dtype=x.dtype, generator=generator)
    sigma = sigma * self.config.noise_tau
    noise = randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
    return x + sigma * noise
||||
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
    """Bicubic-resize to the encoder's square input size, then channel-normalize."""
    target = self.config.encoder_input_size
    height, width = x.shape[-2:]
    if (height, width) != (target, target):
        x = F.interpolate(x, size=(target, target), mode="bicubic", align_corners=False)
    # (x - mean) / std with buffers cast to the input's device/dtype.
    mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
    std = self.encoder_std.to(device=x.device, dtype=x.dtype)
    return (x - mean) / std
||||
def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
    """Invert encoder-input normalization: x * std + mean, channel-wise."""
    std = self.encoder_std.to(device=x.device, dtype=x.dtype)
    mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
    return x * std + mean
||||
def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
    """Standardize latents with the checkpoint stats; epsilon guards a zero std."""
    mu = self._latents_mean.to(device=z.device, dtype=z.dtype)
    sigma = self._latents_std.to(device=z.device, dtype=z.dtype)
    return (z - mu) / (sigma + 1e-5)
||||
def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
    """Exact inverse of `_normalize_latents` (same epsilon on std)."""
    mu = self._latents_mean.to(device=z.device, dtype=z.dtype)
    sigma = self._latents_std.to(device=z.device, dtype=z.dtype)
    return z * (sigma + 1e-5) + mu
||||
def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    """Encode images into latents.

    Pipeline order is load-bearing: encoder forward -> train-only noising ->
    optional 2D reshape -> latent normalization -> optional diffusion scaling.
    `_decode` inverts these steps in reverse order.
    """
    x = self._resize_and_normalize(x)

    # The "mae" forward needs the patch size to reconstruct the token grid;
    # the other encoder forwards take only (model, images).
    if self.config.encoder_type == "mae":
        tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
    else:
        tokens = self._encoder_forward_fn(self.encoder, x)  # (B, N, C)

    # Noise augmentation is applied only in train mode and only when enabled.
    if self.training and self.config.noise_tau > 0:
        tokens = self._noising(tokens, generator=generator)

    if self.config.reshape_to_2d:
        # Fold the token sequence into a square latent grid: (B, N, C) -> (B, C, h, w).
        b, n, c = tokens.shape
        side = int(sqrt(n))
        if side * side != n:
            raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
        z = tokens.transpose(1, 2).contiguous().view(b, c, side, side)  # (B, C, h, w)
    else:
        z = tokens

    z = self._normalize_latents(z)

    # Follow diffusers convention: optionally scale latents for diffusion
    if self.config.scaling_factor != 1.0:
        z = z * self.config.scaling_factor

    return z
||||
@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> EncoderOutput | tuple[torch.Tensor]:
    """Encode a batch of images, processing one sample at a time when slicing is on."""
    if self.use_slicing and x.shape[0] > 1:
        per_sample = [self._encode(chunk, generator=generator) for chunk in x.split(1)]
        latents = torch.cat(per_sample, dim=0)
    else:
        latents = self._encode(x, generator=generator)

    if return_dict:
        return EncoderOutput(latent=latents)
    return (latents,)
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode latents back to pixel space; inverts `_encode` step by step."""
    # Undo scaling factor if applied at encode time
    if self.config.scaling_factor != 1.0:
        z = z / self.config.scaling_factor

    z = self._denormalize_latents(z)

    # Unfold the square latent grid back into a token sequence: (B, C, H, W) -> (B, N, C).
    if self.config.reshape_to_2d:
        b, c, h, w = z.shape
        tokens = z.view(b, c, h * w).transpose(1, 2).contiguous()  # (B, N, C)
    else:
        tokens = z

    # Decoder predicts per-patch pixels; unpatchify stitches patches into an image.
    logits = self.decoder(tokens, return_dict=True).logits
    x_rec = self.decoder.unpatchify(logits)
    # Map from encoder-normalized space back to the original pixel range.
    # NOTE(review): assumes decoder output lives in the encoder-normalized space — confirm.
    x_rec = self._denormalize_image(x_rec)
    return x_rec
||||
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
    """Decode latents into images, one sample at a time when slicing is on."""
    if self.use_slicing and z.shape[0] > 1:
        per_sample = [self._decode(chunk) for chunk in z.split(1)]
        decoded = torch.cat(per_sample, dim=0)
    else:
        decoded = self._decode(z)

    if return_dict:
        return DecoderOutput(sample=decoded)
    return (decoded,)
||||
def forward(
    self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> DecoderOutput | tuple[torch.Tensor]:
    """Full reconstruction pass: encode the images, then decode the latents."""
    z = self.encode(sample, return_dict=False, generator=generator)[0]
    reconstruction = self.decode(z, return_dict=False)[0]
    return DecoderOutput(sample=reconstruction) if return_dict else (reconstruction,)
||||
@@ -656,6 +656,21 @@ class AutoencoderOobleck(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
# Auto-generated placeholder: raises a helpful error when torch is unavailable.
class AutoencoderRAE(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
|
||||
class AutoencoderTiny(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
300
tests/models/autoencoders/test_models_autoencoder_rae.py
Normal file
300
tests/models/autoencoders/test_models_autoencoder_rae.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
import diffusers.models.autoencoders.autoencoder_rae as _rae_module
|
||||
from diffusers.models.autoencoders.autoencoder_rae import (
|
||||
_ENCODER_FORWARD_FNS,
|
||||
AutoencoderRAE,
|
||||
_build_encoder,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny test encoder for fast unit tests (no transformers dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _TinyTestEncoderModule(torch.nn.Module):
    """Stand-in encoder exposing the patch-token interface without any pretrained backbone."""

    def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Collapse channels, then average-pool each patch down to a single scalar.
        gray = images.mean(dim=1, keepdim=True)
        pooled = F.avg_pool2d(gray, kernel_size=self.patch_size, stride=self.patch_size)
        tokens = pooled.flatten(2).transpose(1, 2).contiguous()  # (B, N, 1)
        # Tile the scalar token across the hidden dimension: (B, N, hidden_size).
        return tokens.repeat(1, 1, self.hidden_size)
||||
def _tiny_test_encoder_forward(model, images):
    """Forward dispatcher for the tiny encoder; matches the `_ENCODER_FORWARD_FNS` signature."""
    tokens = model(images)
    return tokens
||||
def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    """Factory matching `_build_encoder`'s signature; depth/type are ignored by design."""
    del encoder_type, num_hidden_layers  # unused: the tiny encoder has no transformer layers
    return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)
||||
# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE.
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
_original_build_encoder = _build_encoder


def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    """Route "tiny_test" to the lightweight test encoder; defer everything else."""
    if encoder_type != "tiny_test":
        return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
    return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)


_rae_module._build_encoder = _patched_build_encoder
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AutoencoderRAETesterConfig(BaseModelTesterConfig):
    """Shared fixture config: a tiny AutoencoderRAE built on the "tiny_test" encoder."""

    @property
    def model_class(self):
        return AutoencoderRAE

    @property
    def output_shape(self):
        # Decoder output is (num_channels, image_size, image_size) per the init dict below.
        return (3, 16, 16)

    def get_init_dict(self):
        # 32x32 inputs, 8-pixel encoder patches -> 4x4 = 16 tokens; decoder emits 16x16 images.
        return {
            "encoder_type": "tiny_test",
            "encoder_hidden_size": 16,
            "encoder_patch_size": 8,
            "encoder_input_size": 32,
            "patch_size": 4,
            "image_size": 16,
            "decoder_hidden_size": 32,
            "decoder_num_hidden_layers": 1,
            "decoder_num_attention_heads": 4,
            "decoder_intermediate_size": 64,
            "num_channels": 3,
            "encoder_norm_mean": [0.5, 0.5, 0.5],
            "encoder_norm_std": [0.5, 0.5, 0.5],
            "noise_tau": 0.0,
            "reshape_to_2d": True,
            "scaling_factor": 1.0,
        }

    @property
    def generator(self):
        # Fresh fixed-seed generator per access so inputs are reproducible.
        return torch.Generator("cpu").manual_seed(0)

    def get_dummy_inputs(self):
        return {"sample": torch.randn(2, 3, 32, 32, generator=self.generator, device="cpu").to(torch_device)}

    # Bridge for AutoencoderTesterMixin which still uses the old interface
    def prepare_init_args_and_inputs_for_common(self):
        return self.get_init_dict(), self.get_dummy_inputs()

    def _make_model(self, **overrides) -> AutoencoderRAE:
        # Build a model from the default init dict, with per-test overrides applied on top.
        config = self.get_init_dict()
        config.update(overrides)
        return AutoencoderRAE(**config).to(torch_device)
||||
|
||||
class TestAutoEncoderRAE(AutoencoderRAETesterConfig, ModelTesterMixin):
    """Core model tests for AutoencoderRAE."""

    @pytest.mark.skip(reason="AutoencoderRAE does not support torch dynamo yet")
    def test_from_save_pretrained_dynamo(self): ...

    def test_fast_encode_decode_and_forward_shapes(self):
        # Shapes: latents are (B, hidden, 4, 4); reconstructions are (B, 3, 16, 16).
        model = self._make_model().eval()
        x = torch.rand(2, 3, 32, 32, device=torch_device)

        with torch.no_grad():
            z = model.encode(x).latent
            decoded = model.decode(z).sample
            recon = model(x).sample

        assert z.shape == (2, 16, 4, 4)
        assert decoded.shape == (2, 3, 16, 16)
        assert recon.shape == (2, 3, 16, 16)
        assert torch.isfinite(recon).all().item()

    def test_fast_scaling_factor_encode_and_decode_consistency(self):
        # Same seed so both models share identical weights; only scaling_factor differs.
        torch.manual_seed(0)
        model_base = self._make_model(scaling_factor=1.0).eval()
        torch.manual_seed(0)
        model_scaled = self._make_model(scaling_factor=2.0).eval()

        x = torch.rand(2, 3, 32, 32, device=torch_device)
        with torch.no_grad():
            z_base = model_base.encode(x).latent
            z_scaled = model_scaled.encode(x).latent
            recon_base = model_base.decode(z_base).sample
            recon_scaled = model_scaled.decode(z_scaled).sample

        # Encoding scales latents by the factor; decode undoes it, so recons match.
        assert torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)
        assert torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)

    def test_fast_latents_normalization_matches_formula(self):
        # Verify the (z - mean) / (std + 1e-5) formula against a model without stats.
        latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32)
        latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32)

        model_raw = self._make_model().eval()
        model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval()
        x = torch.rand(1, 3, 32, 32, device=torch_device)

        with torch.no_grad():
            z_raw = model_raw.encode(x).latent
            z_norm = model_norm.encode(x).latent

        expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / (
            latents_std.to(z_raw.device, z_raw.dtype) + 1e-5
        )
        assert torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)

    def test_fast_slicing_matches_non_slicing(self):
        # Batch slicing is a memory optimization; outputs must be bit-comparable.
        model = self._make_model().eval()
        x = torch.rand(3, 3, 32, 32, device=torch_device)

        with torch.no_grad():
            model.use_slicing = False
            z_no_slice = model.encode(x).latent
            out_no_slice = model.decode(z_no_slice).sample

            model.use_slicing = True
            z_slice = model.encode(x).latent
            out_slice = model.decode(z_slice).sample

        assert torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5)
        assert torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5)

    def test_fast_noise_tau_applies_only_in_train(self):
        # noise_tau adds randomness in train mode only; eval encodes are deterministic.
        model = self._make_model(noise_tau=0.5).to(torch_device)
        x = torch.rand(2, 3, 32, 32, device=torch_device)

        model.train()
        torch.manual_seed(0)
        z_train_1 = model.encode(x).latent
        torch.manual_seed(1)
        z_train_2 = model.encode(x).latent

        model.eval()
        torch.manual_seed(0)
        z_eval_1 = model.encode(x).latent
        torch.manual_seed(1)
        z_eval_2 = model.encode(x).latent

        assert z_train_1.shape == z_eval_1.shape
        assert not torch.allclose(z_train_1, z_train_2)
        assert torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)
||||
|
||||
# Inherits all slicing/tiling cases from AutoencoderTesterMixin; config supplies the model.
class TestAutoEncoderRAESlicingTiling(AutoencoderRAETesterConfig, AutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderRAE."""
|
||||
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests:
    """Shape checks for the real (full-size) encoder backbones."""

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_dinov2_encoder_forward_shape(self):
        encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        y = _ENCODER_FORWARD_FNS["dinov2"](encoder, x)

        assert y.ndim == 3
        assert y.shape[0] == 1
        assert y.shape[1] == 256  # (224/14)^2 = 256 patch tokens (CLS/register tokens stripped by the forward fn)
        assert y.shape[2] == 768

    def test_siglip2_encoder_forward_shape(self):
        encoder = _build_encoder("siglip2", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        y = _ENCODER_FORWARD_FNS["siglip2"](encoder, x)

        assert y.ndim == 3
        assert y.shape[0] == 1
        assert y.shape[1] == 196  # (224/16)^2
        assert y.shape[2] == 768

    def test_mae_encoder_forward_shape(self):
        encoder = _build_encoder("mae", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        # The "mae" forward additionally needs the patch size.
        y = _ENCODER_FORWARD_FNS["mae"](encoder, x, patch_size=16)

        assert y.ndim == 3
        assert y.shape[0] == 1
        assert y.shape[1] == 196  # (224/16)^2
        assert y.shape[2] == 768
|
||||
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEIntegrationTests:
    """End-to-end checks against a published RAE checkpoint (network required)."""

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_autoencoder_rae_from_pretrained_dinov2(self):
        model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device)
        model.eval()

        image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
        )
        image = image.convert("RGB").resize((224, 224))
        x = to_tensor(image).unsqueeze(0).to(torch_device)

        with torch.no_grad():
            latents = model.encode(x).latent
            assert latents.shape == (1, 768, 16, 16)

            recon = model.decode(latents).sample
            assert recon.shape == (1, 3, 256, 256)
            assert torch.isfinite(recon).all().item()

        # Golden values captured from a reference run of this checkpoint.
        # fmt: off
        expected_latent_slice = torch.tensor([0.7617, 0.8824, -0.4891])
        expected_recon_slice = torch.tensor([0.1263, 0.1355, 0.1435])
        # fmt: on

        assert torch_all_close(latents[0, :3, 0, 0].float().cpu(), expected_latent_slice, atol=1e-3)
        assert torch_all_close(recon[0, 0, 0, :3].float().cpu(), expected_recon_slice, atol=1e-3)
||||
Reference in New Issue
Block a user