mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-10 18:51:46 +08:00
fix training script
This commit is contained in:
@@ -26,6 +26,7 @@ accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_r
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
|
||||
@@ -23,6 +23,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
@@ -33,7 +34,7 @@ from diffusers import AutoencoderRAE
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -104,7 +105,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_encoder_loss",
|
||||
action="store_true",
|
||||
help="Enable encoder feature consistency loss in model forward.",
|
||||
help="Enable encoder feature consistency loss term in the training loop.",
|
||||
)
|
||||
parser.add_argument("--report_to", type=str, default="tensorboard")
|
||||
|
||||
@@ -122,9 +123,7 @@ def build_transforms(args):
|
||||
return transforms.Compose(image_transforms)
|
||||
|
||||
|
||||
def compute_losses(
|
||||
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
|
||||
):
|
||||
def compute_losses(model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float):
|
||||
decoded = model(pixel_values).sample
|
||||
|
||||
if decoded.shape[-2:] != pixel_values.shape[-2:]:
|
||||
@@ -140,10 +139,12 @@ def compute_losses(
|
||||
|
||||
encoder_loss = torch.zeros_like(reconstruction_loss)
|
||||
if use_encoder_loss and encoder_loss_weight > 0:
|
||||
target_tokens = model._encode_tokens(
|
||||
model._maybe_resize_and_normalize(pixel_values), requires_grad=False
|
||||
).detach()
|
||||
reconstructed_tokens = model._encode_tokens(model._maybe_resize_and_normalize(decoded), requires_grad=True)
|
||||
base_model = model.module if hasattr(model, "module") else model
|
||||
target_encoder_input = base_model._maybe_resize_and_normalize(pixel_values)
|
||||
reconstructed_encoder_input = base_model._maybe_resize_and_normalize(decoded)
|
||||
|
||||
target_tokens = base_model.encoder(target_encoder_input, requires_grad=False).detach()
|
||||
reconstructed_tokens = base_model.encoder(reconstructed_encoder_input, requires_grad=True)
|
||||
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
|
||||
|
||||
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
|
||||
|
||||
Reference in New Issue
Block a user