fix training script

This commit is contained in:
Kashif Rasul
2026-02-16 13:00:00 +00:00
parent a4fc9f64b2
commit d06b501850
2 changed files with 11 additions and 9 deletions

View File

@@ -26,6 +26,7 @@ accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_r
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1

View File

@@ -23,6 +23,7 @@ from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from torch.utils.data import DataLoader
from torchvision import transforms
@@ -33,7 +34,7 @@ from diffusers import AutoencoderRAE
from diffusers.optimization import get_scheduler
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
def parse_args():
@@ -104,7 +105,7 @@ def parse_args():
parser.add_argument(
"--use_encoder_loss",
action="store_true",
help="Enable encoder feature consistency loss in model forward.",
help="Enable encoder feature consistency loss term in the training loop.",
)
parser.add_argument("--report_to", type=str, default="tensorboard")
@@ -122,9 +123,7 @@ def build_transforms(args):
return transforms.Compose(image_transforms)
def compute_losses(
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
):
def compute_losses(model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float):
decoded = model(pixel_values).sample
if decoded.shape[-2:] != pixel_values.shape[-2:]:
@@ -140,10 +139,12 @@ def compute_losses(
encoder_loss = torch.zeros_like(reconstruction_loss)
if use_encoder_loss and encoder_loss_weight > 0:
target_tokens = model._encode_tokens(
model._maybe_resize_and_normalize(pixel_values), requires_grad=False
).detach()
reconstructed_tokens = model._encode_tokens(model._maybe_resize_and_normalize(decoded), requires_grad=True)
base_model = model.module if hasattr(model, "module") else model
target_encoder_input = base_model._maybe_resize_and_normalize(pixel_values)
reconstructed_encoder_input = base_model._maybe_resize_and_normalize(decoded)
target_tokens = base_model.encoder(target_encoder_input, requires_grad=False).detach()
reconstructed_tokens = base_model.encoder(reconstructed_encoder_input, requires_grad=True)
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss