diff --git a/examples/research_projects/autoencoder_rae/README.md b/examples/research_projects/autoencoder_rae/README.md index 559eb37518..c6ffe77112 100644 --- a/examples/research_projects/autoencoder_rae/README.md +++ b/examples/research_projects/autoencoder_rae/README.md @@ -26,6 +26,7 @@ accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_r --train_batch_size 8 \ --learning_rate 1e-4 \ --num_train_epochs 10 \ + --report_to wandb \ --reconstruction_loss_type l1 \ --use_encoder_loss \ --encoder_loss_weight 0.1 diff --git a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py index 4b0a2c5516..72f030a428 100644 --- a/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py +++ b/examples/research_projects/autoencoder_rae/train_autoencoder_rae.py @@ -23,6 +23,7 @@ from pathlib import Path import torch import torch.nn.functional as F from accelerate import Accelerator +from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from torch.utils.data import DataLoader from torchvision import transforms @@ -33,7 +34,7 @@ from diffusers import AutoencoderRAE from diffusers.optimization import get_scheduler -logger = logging.getLogger(__name__) +logger = get_logger(__name__) def parse_args(): @@ -104,7 +105,7 @@ def parse_args(): parser.add_argument( "--use_encoder_loss", action="store_true", - help="Enable encoder feature consistency loss in model forward.", + help="Enable encoder feature consistency loss term in the training loop.", ) parser.add_argument("--report_to", type=str, default="tensorboard") @@ -122,9 +123,7 @@ def build_transforms(args): return transforms.Compose(image_transforms) -def compute_losses( - model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float -): +def compute_losses(model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float): decoded = model(pixel_values).sample if decoded.shape[-2:] != pixel_values.shape[-2:]: @@ -140,10 +139,12 @@ def compute_losses( encoder_loss = torch.zeros_like(reconstruction_loss) if use_encoder_loss and encoder_loss_weight > 0: - target_tokens = model._encode_tokens( - model._maybe_resize_and_normalize(pixel_values), requires_grad=False - ).detach() - reconstructed_tokens = model._encode_tokens(model._maybe_resize_and_normalize(decoded), requires_grad=True) + base_model = model.module if hasattr(model, "module") else model + target_encoder_input = base_model._maybe_resize_and_normalize(pixel_values) + reconstructed_encoder_input = base_model._maybe_resize_and_normalize(decoded) + + target_tokens = base_model.encoder(target_encoder_input, requires_grad=False).detach() + reconstructed_tokens = base_model.encoder(reconstructed_encoder_input, requires_grad=True) encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float()) loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss