Compare commits

...

1 Commits

Author SHA1 Message Date
Sayak Paul
a7d1171ea3 incorporate memory logging. 2023-06-20 15:23:56 +05:30

View File

@@ -21,10 +21,12 @@ import logging
import math
import os
import shutil
import threading
import warnings
from pathlib import Path
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -596,6 +598,61 @@ def collate_fn(examples, with_prior_preservation=False):
return batch
# Converting Bytes to Megabytes
def b2mb(x):
return int(x / 2**20)
# From:
# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
def __enter__(self):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
self.process = psutil.Process()
self.cpu_begin = self.cpu_mem_used()
self.peak_monitoring = True
peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
peak_monitor_thread.daemon = True
peak_monitor_thread.start()
return self
def cpu_mem_used(self):
"""get resident set size memory for the current process"""
return self.process.memory_info().rss
def peak_monitor_func(self):
self.cpu_peak = -1
while True:
self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
# can't sleep or will not catch the peak right (this comment is here on purpose)
# time.sleep(0.001) # 1msec
if not self.peak_monitoring:
break
def __exit__(self, *exc):
self.peak_monitoring = False
gc.collect()
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
self.used = b2mb(self.end - self.begin)
self.peaked = b2mb(self.peak - self.begin)
self.cpu_end = self.cpu_mem_used()
self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
# print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -1113,139 +1170,159 @@ def main(args):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
with TorchTracemalloc() as tracemalloc:
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
if vae is not None:
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
if vae is not None:
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Get the text embedding for conditioning
if args.pre_compute_text_embeddings:
encoder_hidden_states = batch["input_ids"]
else:
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
batch["attention_mask"],
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None
# Get the text embedding for conditioning
if args.pre_compute_text_embeddings:
encoder_hidden_states = batch["input_ids"]
else:
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
batch["attention_mask"],
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine tuned
# model must be configured to use one of the fixed variance variance types.
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine tuned
# model must be configured to use one of the fixed variance variance types.
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
global_step += 1
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if global_step >= args.max_train_steps:
break
# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
accelerator.print("GPU Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
accelerator.print("GPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
accelerator.print("GPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
accelerator.print(
"GPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.peaked + b2mb(tracemalloc.begin)
)
)
accelerator.print("CPU Memory before entering the train : {}".format(b2mb(tracemalloc.cpu_begin)))
accelerator.print("CPU Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.cpu_used))
accelerator.print("CPU Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.cpu_peaked))
accelerator.print(
"CPU Total Peak Memory consumed during the train (max): {}".format(
tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)
)
)
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0: