Files
diffusers/scripts/convert_rae_to_diffusers.py
Ando 8ec0a5ccad feat: implement rae autoencoder. (#13046)
* feat: implement three RAE encoders(dinov2, siglip2, mae)

* feat: finish first version of autoencoder_rae

* fix formatting

* make fix-copies

* initial doc

* fix latent_mean / latent_var init types to accept config-friendly inputs

* use mean and std convention

* cleanup

* add rae to diffusers script

* use imports

* use attention

* remove unneeded class

* example training script

* input and ground truth sizes have to be the same

* fix argument

* move loss to training script

* cleanup

* simplify mixins

* fix training script

* fix entrypoint for instantiating the AutoencoderRAE

* added encoder_image_size config

* undo last change

* fixes from pretrained weights

* cleanups

* address reviews

* fix train script to use pretrained

* fix conversion script review

* latent normalization buffers are now always registered with no-op defaults

* Update examples/research_projects/autoencoder_rae/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/autoencoders/autoencoder_rae.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* use image url

* Encoder is frozen

* fix slow test

* remove config

* use ModelTesterMixin and AutoencoderTesterMixin

* make quality

* strip final layernorm when converting

* _strip_final_layernorm_affine for training script

* fix test

* add dispatch forward and update conversion script

* update training script

* error out as soon as possible and add comments

* Update src/diffusers/models/autoencoders/autoencoder_rae.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* use buffer

* inline

* Update src/diffusers/models/autoencoders/autoencoder_rae.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* remove optional

* _noising takes a generator

* Update src/diffusers/models/autoencoders/autoencoder_rae.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix api

* rename

* remove unittest

* use randn_tensor

* fix device map on multigpu

* check if the key is missing in the original state dict and only then add to the allow_missing set

* remove initialize_weights

---------

Co-authored-by: wangyuqi <wangyuqi@MBP-FJDQNJTWYN-0208.local>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-05 20:17:14 +05:30

407 lines
16 KiB
Python

import argparse
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi, hf_hub_download
from diffusers import AutoencoderRAE
# Decoder transformer hyperparameters for each supported decoder size preset.
DECODER_CONFIGS = {
    "ViTB": {
        "decoder_hidden_size": 768,
        "decoder_intermediate_size": 3072,
        "decoder_num_attention_heads": 12,
        "decoder_num_hidden_layers": 12,
    },
    "ViTL": {
        "decoder_hidden_size": 1024,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 24,
    },
    "ViTXL": {
        "decoder_hidden_size": 1152,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 28,
    },
}
# Default Hugging Face model id per encoder family, used when
# --encoder_name_or_path is not supplied.
ENCODER_DEFAULT_NAME_OR_PATH = {
    "dinov2": "facebook/dinov2-with-registers-base",
    "siglip2": "google/siglip2-base-patch16-256",
    "mae": "facebook/vit-mae-base",
}
# Fallback encoder hidden/patch sizes; convert() prefers the values read from
# the encoder's HF config when that config can be fetched.
ENCODER_HIDDEN_SIZE = {
    "dinov2": 768,
    "siglip2": 768,
    "mae": 768,
}
ENCODER_PATCH_SIZE = {
    "dinov2": 14,
    "siglip2": 16,
    "mae": 16,
}
# Default folder layout of the upstream RAE checkpoint repo/path.
DEFAULT_DECODER_SUBDIR = {
    "dinov2": "decoders/dinov2/wReg_base",
    "mae": "decoders/mae/base_p16",
    "siglip2": "decoders/siglip2/base_p16_i256",
}
DEFAULT_STATS_SUBDIR = {
    "dinov2": "stats/dinov2/wReg_base",
    "mae": "stats/mae/base_p16",
    "siglip2": "stats/siglip2/base_p16_i256",
}
# Checkpoint filenames probed inside each decoder/stats folder, in order.
# NOTE(review): "dinov2_decoder.pt" is tried for every encoder type —
# presumably a legacy filename in the upstream repo; confirm against its layout.
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
STATS_FILE_CANDIDATES = ("stat.pt",)
def dataset_case_candidates(name: str) -> tuple[str, ...]:
    """Return unique case variants of *name* (plus ImageNet fallbacks) to probe for stats folders.

    Fix: the raw variant tuple can contain duplicates (e.g. for "imagenet1k",
    `name`, `name.lower()` and the literal fallback coincide), which made the
    caller probe the same repo path several times. Duplicates are now removed
    while preserving first-seen order, so each candidate is checked at most once.
    """
    variants = (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
    # dict.fromkeys keeps insertion order (guaranteed since Python 3.7).
    return tuple(dict.fromkeys(variants))
class RepoAccessor:
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
self.repo_or_path = repo_or_path
self.cache_dir = cache_dir
self.local_root: Path | None = None
self.repo_id: str | None = None
self.repo_files: set[str] | None = None
root = Path(repo_or_path)
if root.exists() and root.is_dir():
self.local_root = root
else:
self.repo_id = repo_or_path
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
def exists(self, relative_path: str) -> bool:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return (self.local_root / relative_path).is_file()
return relative_path in self.repo_files
def fetch(self, relative_path: str) -> Path:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return self.local_root / relative_path
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
return Path(downloaded)
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
    """Unwrap common checkpoint containers and strip uniform key prefixes.

    Descends through "model"/"module"/"state_dict" wrapper keys (in that order),
    then removes a leading "module." and then "decoder." prefix when every key
    carries it.
    """
    inner = maybe_wrapped
    for wrapper_key in ("model", "module", "state_dict"):
        if isinstance(inner, dict):
            nested = inner.get(wrapper_key)
            if isinstance(nested, dict):
                inner = nested
    result = dict(inner)
    for prefix in ("module.", "decoder."):
        if result and all(key.startswith(prefix) for key in result):
            result = {key[len(prefix):]: value for key, value in result.items()}
    return result
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder.
    Example mappings:
    - `...attention.attention.query.*` -> `...attention.to_q.*`
    - `...attention.attention.key.*` -> `...attention.to_k.*`
    - `...attention.attention.value.*` -> `...attention.to_v.*`
    - `...attention.output.dense.*` -> `...attention.to_out.0.*`
    """
    # (old substring, new substring) rewrite rules, applied to every key.
    rules = (
        (".attention.attention.query.", ".attention.to_q."),
        (".attention.attention.key.", ".attention.to_k."),
        (".attention.attention.value.", ".attention.to_v."),
        (".attention.output.dense.", ".attention.to_out.0."),
    )
    renamed: dict[str, Any] = {}
    for original_key, tensor in state_dict.items():
        key = original_key
        for old, new in rules:
            key = key.replace(old, new)
        renamed[key] = tensor
    return renamed
def resolve_decoder_file(
    accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
    """Resolve the repo-relative decoder checkpoint path.

    An explicit *decoder_checkpoint* override is validated and returned as-is;
    otherwise the default decoder folder for *encoder_type*/*variant* is probed
    for each known checkpoint filename. Raises FileNotFoundError on failure.
    """
    if decoder_checkpoint is not None:
        if not accessor.exists(decoder_checkpoint):
            raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
        return decoder_checkpoint
    base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
    for filename in DECODER_FILE_CANDIDATES:
        relative = f"{base}/{filename}"
        if accessor.exists(relative):
            return relative
    raise FileNotFoundError(
        f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
    )
def resolve_stats_file(
    accessor: RepoAccessor,
    encoder_type: str,
    dataset_name: str,
    stats_checkpoint: str | None,
) -> str | None:
    """Resolve the repo-relative latent-stats path, or None when nothing is found.

    An explicit *stats_checkpoint* override is validated and returned as-is;
    otherwise every case variant of *dataset_name* is probed under the default
    stats folder. Unlike the decoder, a missing stats file is not an error.
    """
    if stats_checkpoint is not None:
        if not accessor.exists(stats_checkpoint):
            raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
        return stats_checkpoint
    base = DEFAULT_STATS_SUBDIR[encoder_type]
    probes = (
        f"{base}/{dataset}/{filename}"
        for dataset in dataset_case_candidates(dataset_name)
        for filename in STATS_FILE_CANDIDATES
    )
    for relative in probes:
        if accessor.exists(relative):
            return relative
    return None
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
if not isinstance(stats_obj, dict):
return None, None
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
mean = stats_obj.get("mean", None)
var = stats_obj.get("var", None)
if mean is None and var is None:
return None, None
latents_std = None
if var is not None:
if isinstance(var, torch.Tensor):
latents_std = torch.sqrt(var + 1e-5)
else:
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
return mean, latents_std
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
"""Remove final layernorm weight/bias from encoder state dict.
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
Stripping these keys means the model keeps its default init values, which
is functionally equivalent to setting elementwise_affine=False.
"""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
    """Download the HF encoder and extract the state dict for the inner model.

    For every encoder family the final layernorm's affine parameters are
    stripped so the converted model keeps its non-affine defaults.
    """
    if encoder_type == "dinov2":
        from transformers import Dinov2WithRegistersModel

        state_dict = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path).state_dict()
        return _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
    if encoder_type == "siglip2":
        from transformers import SiglipModel

        # SiglipModel.vision_model is a SiglipVisionTransformer.
        # Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
        # under .vision_model, so prefix every key to match the diffusers layout.
        vision_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
        state_dict = {f"vision_model.{key}": value for key, value in vision_model.state_dict().items()}
        return _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
    if encoder_type == "mae":
        from transformers import ViTMAEForPreTraining

        state_dict = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit.state_dict()
        return _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
    raise ValueError(f"Unknown encoder_type: {encoder_type}")
def convert(args: argparse.Namespace) -> None:
    """Convert an official RAE decoder checkpoint into a diffusers `AutoencoderRAE`.

    Pipeline: resolve source files (local dir or HF Hub repo), load and remap the
    decoder state dict, fetch encoder weights and normalization stats from the HF
    encoder, assemble a full state dict, load it into a meta-initialized model
    with assign=True, audit missing keys, then `save_pretrained` to
    `args.output_path` (optionally reloading to verify).
    """
    accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
    encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]
    decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
    stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)
    print(f"Using decoder checkpoint: {decoder_relpath}")
    if stats_relpath is not None:
        print(f"Using stats checkpoint: {stats_relpath}")
    else:
        # Stats are optional; no-op latent normalization buffers are written below.
        print("No stats checkpoint found; conversion will proceed without latent stats.")
    if args.dry_run:
        # Dry run only reports which files would be used; nothing is downloaded/loaded.
        return
    decoder_path = accessor.fetch(decoder_relpath)
    # NOTE(review): torch.load without weights_only executes pickled code from the
    # checkpoint — acceptable for a trusted conversion source, but worth confirming.
    decoder_obj = torch.load(decoder_path, map_location="cpu")
    decoder_state_dict = unwrap_state_dict(decoder_obj)
    decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)
    latents_mean, latents_std = None, None
    if stats_relpath is not None:
        stats_path = accessor.fetch(stats_relpath)
        stats_obj = torch.load(stats_path, map_location="cpu")
        latents_mean, latents_std = extract_latent_stats(stats_obj)
    decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]
    # Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
    from transformers import AutoConfig, AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
    encoder_norm_mean = list(proc.image_mean)
    encoder_norm_std = list(proc.image_std)
    # Read encoder hidden size and patch size from HF config
    encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
    encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
    try:
        hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
        # For models like SigLIP that nest vision config
        if hasattr(hf_config, "vision_config"):
            hf_config = hf_config.vision_config
        encoder_hidden_size = hf_config.hidden_size
        encoder_patch_size = hf_config.patch_size
    except Exception:
        # Best effort: fall back to the hard-coded table values set above.
        pass
    # Load the actual encoder weights from HF to include in the saved model
    encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)
    # Build model on meta device to avoid double init overhead
    with torch.device("meta"):
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=encoder_hidden_size,
            encoder_patch_size=encoder_patch_size,
            encoder_input_size=args.encoder_input_size,
            patch_size=args.patch_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            encoder_norm_mean=encoder_norm_mean,
            encoder_norm_std=encoder_norm_std,
            decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
            decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
            decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
            decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
            latents_mean=latents_mean,
            latents_std=latents_std,
            scaling_factor=args.scaling_factor,
        )
    # Assemble full state dict and load with assign=True
    # (assign=True is required: meta-device parameters cannot be copied into,
    # only replaced.)
    full_state_dict = {}
    # Encoder weights (prefixed with "encoder.")
    for k, v in encoder_state_dict.items():
        full_state_dict[f"encoder.{k}"] = v
    # Decoder weights (prefixed with "decoder.")
    for k, v in decoder_state_dict.items():
        full_state_dict[f"decoder.{k}"] = v
    # Buffers from config
    full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
    # Latent normalization buffers: real stats when available, identity (0 mean,
    # 1 std) no-op defaults otherwise.
    if latents_mean is not None:
        latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
        full_state_dict["_latents_mean"] = latents_mean_t
    else:
        full_state_dict["_latents_mean"] = torch.zeros(1)
    if latents_std is not None:
        latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
        full_state_dict["_latents_std"] = latents_std_t
    else:
        full_state_dict["_latents_std"] = torch.ones(1)
    model.load_state_dict(full_state_dict, strict=False, assign=True)
    # Verify no critical keys are missing
    model_keys = {name for name, _ in model.named_parameters()}
    model_keys |= {name for name, _ in model.named_buffers()}
    loaded_keys = set(full_state_dict.keys())
    missing = model_keys - loaded_keys
    # decoder_pos_embed is initialized in-model. trainable_cls_token is only
    # allowed to be missing if it was absent in the source decoder checkpoint.
    allowed_missing = {"decoder.decoder_pos_embed"}
    if "trainable_cls_token" not in decoder_state_dict:
        allowed_missing.add("decoder.trainable_cls_token")
    if missing - allowed_missing:
        # Warn but do not abort: strict=False loading already succeeded above.
        print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)
    if args.verify_load:
        # Round-trip check: reload the saved directory through from_pretrained.
        print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
        loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
        if not isinstance(loaded_model, AutoencoderRAE):
            raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
        print("Verification passed.")
    print(f"Saved converted AutoencoderRAE to: {output_path}")
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the RAE -> diffusers conversion script."""
    arg_parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
    # Source and destination.
    arg_parser.add_argument(
        "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
    )
    arg_parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")
    # Encoder selection.
    arg_parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
    arg_parser.add_argument(
        "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
    )
    # Decoder / stats location inside the source repo.
    arg_parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
    arg_parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
    arg_parser.add_argument(
        "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
    )
    arg_parser.add_argument(
        "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
    )
    arg_parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS), default="ViTXL")
    # Model geometry.
    arg_parser.add_argument("--encoder_input_size", type=int, default=224)
    arg_parser.add_argument("--patch_size", type=int, default=16)
    arg_parser.add_argument("--image_size", type=int, default=None)
    arg_parser.add_argument("--num_channels", type=int, default=3)
    arg_parser.add_argument("--scaling_factor", type=float, default=1.0)
    # Misc.
    arg_parser.add_argument("--cache_dir", type=str, default=None)
    arg_parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
    arg_parser.add_argument(
        "--verify_load",
        action="store_true",
        help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
    )
    return arg_parser.parse_args()
def main() -> None:
    """CLI entry point: parse arguments and run the conversion."""
    convert(parse_args())


if __name__ == "__main__":
    main()