Compare commits


15 Commits

Author     SHA1        Message                  Date
sayakpaul  3a95742edc  fix                      2023-11-27 15:09:35 +05:30
sayakpaul  6ff90f6c70  reduce batch size.       2023-11-27 14:57:23 +05:30
sayakpaul  b5e405168d  command edit             2023-11-27 14:27:22 +05:30
sayakpaul  8dbc46dfa9  mkdir                    2023-11-27 14:22:45 +05:30
sayakpaul  466553b885  mkdir                    2023-11-27 14:17:46 +05:30
sayakpaul  3a5ef6c78f  add: slurm script.       2023-11-17 16:45:12 +05:30
sayakpaul  2a64edcb2c  fix                      2023-11-17 14:47:33 +05:30
sayakpaul  ed2a52daf6  fix validation step      2023-11-17 14:41:50 +05:30
sayakpaul  148d6f9e58  fix validation stepping  2023-11-17 14:25:56 +05:30
sayakpaul  f35b76c523  fix: null embeddings     2023-11-17 14:13:08 +05:30
sayakpaul  d3eea16750  up                       2023-11-17 14:08:50 +05:30
sayakpaul  1d486c95a1  up                       2023-11-17 13:50:00 +05:30
sayakpaul  0fcff42916  partial up               2023-11-17 13:32:27 +05:30
sayakpaul  21fb55844e  fix                      2023-11-17 13:18:18 +05:30
sayakpaul  c8b88f8b31  initial                  2023-11-17 12:59:52 +05:30
2 changed files with 420 additions and 247 deletions

View File

@@ -0,0 +1,124 @@
#!/bin/bash
#SBATCH --job-name=instruct-pix2pix-sdxl-emu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node!
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
#SBATCH --exclusive
#SBATCH --partition=production-cluster
#SBATCH --output=/admin/home/sayak/logs/instruct-pix2pix-sdxl-emu/%x-%j.out
set -x -e
source /admin/home/sayak/.bashrc
source /admin/home/sayak/miniconda3/etc/profile.d/conda.sh
conda activate diffusers
echo "START TIME: $(date)"
REPO=/fsx/sayak/diffusers/examples/instruct_pix2pix
OUTPUT_DIR=/fsx/sayak/instruct-pix2pix-sdxl-emu
LOG_PATH=$OUTPUT_DIR/main_log.txt
ACCELERATE_CONFIG_FILE="$OUTPUT_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"
mkdir -p $OUTPUT_DIR
touch $LOG_PATH
pushd $REPO
GPUS_PER_NODE=8
NNODES=$SLURM_NNODES
NUM_GPUS=$((GPUS_PER_NODE*SLURM_NNODES))
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
# Auto-generate the accelerate config
cat << EOT > $ACCELERATE_CONFIG_FILE
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT
export MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_ID="facebook/emu_edit_test_set_generations"
PROGRAM="train_instruct_pix2pix_sdxl.py \
--pretrained_model_name_or_path=$MODEL_ID \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--dataset_name=$DATASET_ID \
--original_image_column=image --edited_image_column=edited_image --edit_prompt_column=instruction \
--resolution=1024 \
--train_batch_size=8 --gradient_accumulation_steps=4 --gradient_checkpointing \
--dataloader_num_workers=8 \
--enable_xformers_memory_efficient_attention \
--max_train_steps=10000 \
--checkpointing_steps=2500 \
--learning_rate=1e-5 --lr_warmup_steps=0 \
--mixed_precision=fp16 \
--val_image_url_or_path='https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png' \
--validation_prompt='Turn sky into a cloudy one' \
--seed=42 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
--push_to_hub
"
# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
export LAUNCHER="accelerate launch \
--rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,max_restarts=0,tee=3" \
--config_file $ACCELERATE_CONFIG_FILE \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
--num_processes $NUM_GPUS \
--machine_rank \$SLURM_PROCID \
"
export CMD="$LAUNCHER $PROGRAM"
echo $CMD
# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens
# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee $LOG_PATH
echo "END TIME: $(date)"

View File

@@ -18,6 +18,7 @@ import argparse
import logging
import math
import os
import random
import shutil
import warnings
from pathlib import Path
@@ -35,11 +36,12 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
@@ -54,6 +56,10 @@ from diffusers.utils import check_min_version, deprecate, is_wandb_available, lo
from diffusers.utils.import_utils import is_xformers_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.24.0.dev0")
@@ -62,7 +68,7 @@ logger = get_logger(__name__, log_level="INFO")
DATASET_NAME_MAPPING = {
"fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
}
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
@@ -86,6 +92,133 @@ def import_model_class_from_model_name_or_path(
raise ValueError(f"{model_class} is not supported.")
def tokenize_prompt(tokenizer, prompt):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders):
if tokenizers is not None:
tokenizer = tokenizers[i]
text_input_ids = tokenize_prompt(tokenizer, prompt)
else:
assert text_input_ids_list is not None
text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
# We are always only interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return prompt_embeds, pooled_prompt_embeds
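# Shape sketch (assuming the two standard SDXL-base text encoders, CLIP ViT-L and OpenCLIP
# ViT-bigG): the per-encoder hidden states are 768- and 1280-dimensional, so the concatenated
# `prompt_embeds` is [batch_size, 77, 2048], while `pooled_prompt_embeds` comes from the last
# text encoder and is [batch_size, 1280].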
def log_validation(
vae,
unet,
text_encoder_1,
text_encoder_2,
tokenizer_1,
tokenizer_2,
args,
accelerator,
weight_dtype,
global_step,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# The models need unwrapping for compatibility in distributed training mode.
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=text_encoder_1,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
# run inference
# Save validation images
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
original_image = original_image.resize((args.resolution, args.resolution))
with torch.autocast("cuda"):
edited_images = []
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
height=args.resolution,
width=args.resolution,
image=original_image,
num_inference_steps=25,
image_guidance_scale=1.5,
guidance_scale=5.0,
generator=generator,
).images[0]
edited_images.append(a_val_img)
a_val_img.save(
os.path.join(
val_save_dir,
f"step_{global_step}_val_img_{val_img_idx}.png",
)
)
formatted_images = [wandb.Image(original_image, caption="Original Image")]
for edited_image in edited_images:
formatted_images.append(wandb.Image(edited_image, caption=args.validation_prompt))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
tracker.log({"validation": formatted_images})
del pipeline
torch.cuda.empty_cache()
def parse_args():
parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.")
parser.add_argument(
@@ -177,15 +310,7 @@ def parse_args():
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run fine-tuning validation every X steps. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument("--validation_epochs", type=int, default=1, help="Run fine-tuning validation every X epochs.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -198,7 +323,7 @@ def parse_args():
parser.add_argument(
"--output_dir",
type=str,
default="instruct-pix2pix-model",
default="instruct-pix2pix-sdxl",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
@@ -216,18 +341,6 @@ def parse_args():
"The resolution for input images, all the images in the train/validation dataset will be resized to this resolution."
),
)
parser.add_argument(
"--crops_coords_top_left_h",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--crops_coords_top_left_w",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--center_crop",
default=False,
@@ -443,7 +556,6 @@ def main():
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -605,6 +717,7 @@ def main():
args.dataset_config_name,
cache_dir=args.cache_dir,
)
dataset = concatenate_datasets([dataset["validation"], dataset["test"]])
else:
data_files = {}
if args.train_data_dir is not None:
@@ -619,7 +732,7 @@ def main():
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
column_names = dataset.column_names
# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
@@ -659,40 +772,6 @@ def main():
weight_dtype = torch.bfloat16
warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(captions, tokenizer):
inputs = tokenizer(
captions,
max_length=tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
return inputs.input_ids
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]
)
def preprocess_images(examples):
original_images = np.concatenate(
[convert_to_np(image, args.resolution) for image in examples[original_image_column]]
)
edited_images = np.concatenate(
[convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
)
# We need to ensure that the original and the edited images undergo the same
# augmentation transforms.
images = np.concatenate([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1
return train_transforms(images)
# Load scheduler, tokenizer and models.
tokenizer_1 = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
@@ -729,132 +808,111 @@ def main():
# Set UNet to trainable.
unet.train()
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt):
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
# Preprocessing the datasets.
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[edit_prompt_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{edit_prompt_column}` should contain either strings or lists of strings."
)
tokens_one = tokenize_prompt(tokenizer_1, captions)
tokens_two = tokenize_prompt(tokenizer_2, captions)
return tokens_one, tokens_two
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
normalize = transforms.Normalize([0.5], [0.5])
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
def preprocess_train(samples):
orig_images = [image.convert("RGB") for image in samples[original_image_column]]
edited_images = [image.convert("RGB") for image in samples[edited_image_column]]
resized_edited_images = []
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return prompt_embeds, pooled_prompt_embeds
# Resize edited images if necessary.
for edited_image, orig_image in zip(edited_images, orig_images):
if edited_image.size != orig_image.size:
edited_image = edited_image.resize(orig_image.size)
resized_edited_images.append(edited_image)
else:
resized_edited_images.append(edited_image)
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompts(text_encoders, tokenizers, prompts):
prompt_embeds_all = []
pooled_prompt_embeds_all = []
# Main image processing.
final_original_images = []
final_edited_images = []
original_sizes = []
crop_top_lefts = []
for edited_image, orig_image in zip(resized_edited_images, orig_images):
original_sizes.append((orig_image.height, orig_image.width))
for prompt in prompts:
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
prompt_embeds_all.append(prompt_embeds)
pooled_prompt_embeds_all.append(pooled_prompt_embeds)
images = torch.stack([transforms.ToTensor()(orig_image), transforms.ToTensor()(edited_image)])
images = train_resize(images)
if args.center_crop:
y1 = max(0, int(round((orig_image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((orig_image.width - args.resolution) / 2.0)))
images = train_crop(images)
else:
y1, x1, h, w = train_crop.get_params(images, (args.resolution, args.resolution))
images = crop(images, y1, x1, h, w)
return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)
if args.random_flip and random.random() < 0.5:
# flip
x1 = orig_image.width - x1
images = train_flip(images)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
# Adapted from examples.dreambooth.train_dreambooth_lora_sdxl
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)
add_text_embeds_all = pooled_prompt_embeds_all
transformed_images = normalize(images)
prompt_embeds_all = prompt_embeds_all.to(accelerator.device)
add_text_embeds_all = add_text_embeds_all.to(accelerator.device)
return prompt_embeds_all, add_text_embeds_all
# Separate the original and edited images and the edit prompt.
original_image, edited_image = transformed_images.chunk(2)
original_image = original_image.squeeze(0)
edited_image = edited_image.squeeze(0)
final_original_images.append(original_image)
final_edited_images.append(edited_image)
# Get null conditioning
def compute_null_conditioning():
null_conditioning_list = []
for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):
null_conditioning_list.append(
a_text_encoder(
tokenize_captions([""], tokenizer=a_tokenizer).to(accelerator.device),
output_hidden_states=True,
).hidden_states[-2]
)
return torch.concat(null_conditioning_list, dim=-1)
null_conditioning = compute_null_conditioning()
def compute_time_ids():
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
original_size = target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)
add_time_ids = compute_time_ids()
def preprocess_train(examples):
# Preprocess images.
preprocessed_images = preprocess_images(examples)
# Since the original and edited images were concatenated before
# applying the transformations, we need to separate them and reshape
# them accordingly.
original_images, edited_images = preprocessed_images.chunk(2)
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
# Collate the preprocessed images into the `examples`.
examples["original_pixel_values"] = original_images
examples["edited_pixel_values"] = edited_images
# Preprocess the captions.
captions = list(examples[edit_prompt_column])
prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)
examples["prompt_embeds"] = prompt_embeds_all
examples["add_text_embeds"] = add_text_embeds_all
return examples
# Pack the values.
samples["original_sizes"] = original_sizes
samples["crop_top_lefts"] = crop_top_lefts
samples["original_pixel_values"] = final_original_images
samples["edited_pixel_values"] = final_original_images
tokens_one, tokens_two = tokenize_captions(samples)
samples["input_ids_one"] = tokens_one
samples["input_ids_two"] = tokens_two
return samples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
dataset = dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
train_dataset = dataset.with_transform(preprocess_train)
def collate_fn(examples):
original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0)
add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0)
original_sizes = [example["original_sizes"] for example in examples]
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
return {
"original_pixel_values": original_pixel_values,
"edited_pixel_values": edited_pixel_values,
"prompt_embeds": prompt_embeds,
"add_text_embeds": add_text_embeds,
"input_ids_one": input_ids_one,
"input_ids_two": input_ids_two,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
}
# DataLoaders creation:
@@ -947,6 +1005,12 @@ def main():
else:
initial_global_step = 0
# Get null conditioning.
# Remains fixed throughout training.
null_conditioning_prompt_embeds, null_conditioning_pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, [""]
)
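# Descriptive note: these empty-prompt embeddings act as the unconditional text branch. During
# training they replace the real prompt embeddings for samples selected by the conditioning
# dropout masks below, which is what later enables classifier-free guidance on the edit prompt.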
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
@@ -982,9 +1046,13 @@ def main():
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# SDXL additional inputs
encoder_hidden_states = batch["prompt_embeds"]
add_text_embeds = batch["add_text_embeds"]
# Encode prompts.
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_1, text_encoder_2],
tokenizers=None,
prompt=None,
text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
)
# Get the additional image embedding for conditioning.
# Instead of getting a diagonal Gaussian here, we simply take the mode.
@@ -992,7 +1060,7 @@ def main():
original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
else:
original_pixel_values = batch["original_pixel_values"]
original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()
original_image_embeds = vae.encode(original_pixel_values).latent_dist.mode()
if args.pretrained_vae_model_name_or_path is None:
original_image_embeds = original_image_embeds.to(weight_dtype)
@@ -1003,8 +1071,13 @@ def main():
# Sample masks for the edit prompts.
prompt_mask = random_p < 2 * args.conditioning_dropout_prob
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
pooled_prompt_mask = prompt_mask.reshape(bsz, 1)
# Final text conditioning.
encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
prompt_embeds = torch.where(prompt_mask, null_conditioning_prompt_embeds, prompt_embeds)
pooled_prompt_embeds = torch.where(
pooled_prompt_mask, null_conditioning_pooled_prompt_embeds, pooled_prompt_embeds
)
# Sample masks for the original images.
image_mask_dtype = original_image_embeds.dtype
@@ -1027,11 +1100,24 @@ def main():
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# Compute additional embedding inputs.
# time ids
def compute_time_ids(original_size, crops_coords_top_left):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
return add_time_ids
add_time_ids = torch.cat(
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
)
unet_added_conditions = {"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}
# Predict the noise residual and compute loss
model_pred = unet(
concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
concatenated_noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -1056,8 +1142,8 @@ def main():
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
@@ -1085,81 +1171,37 @@ def main():
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
### BEGIN: Perform validation every `validation_epochs` steps
if global_step % args.validation_steps == 0 or global_step == 1:
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=text_encoder_1,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
# Save validation images
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
edited_images = []
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
del pipeline
torch.cuda.empty_cache()
### END: Perform validation every `validation_epochs` steps
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if (
(args.val_image_url_or_path is not None)
and (args.validation_prompt is not None)
and (epoch % args.validation_epochs == 0)
):
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
log_validation(
vae=vae,
unet=unet,
text_encoder_1=text_encoder_1,
text_encoder_2=text_encoder_2,
tokenizer_1=tokenizer_1,
tokenizer_2=tokenizer_2,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
global_step=global_step,
)
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -1189,8 +1231,15 @@ def main():
if args.validation_prompt is not None:
edited_images = []
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
original_image = original_image.resize((args.resolution, args.resolution))
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
with torch.autocast(str(accelerator.device)):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
@@ -1198,7 +1247,7 @@ def main():
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
guidance_scale=5.0,
generator=generator,
).images[0]
)
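
After training completes and the weights are pushed to the Hub via `--push_to_hub`, the resulting model can be loaded back with the same pipeline class used in the script. A minimal inference sketch, assuming a placeholder repository id and reusing the validation image, prompt, and guidance settings from the SLURM command above:

import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

# "your-username/instruct-pix2pix-sdxl-emu" is a placeholder for the repo created by --push_to_hub.
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "your-username/instruct-pix2pix-sdxl-emu", torch_dtype=torch.float16
).to("cuda")

# Same validation image and prompt as in the training command.
image = load_image(
    "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
).resize((1024, 1024))

edited_image = pipeline(
    "Turn sky into a cloudy one",
    image=image,
    num_inference_steps=25,
    image_guidance_scale=1.5,
    guidance_scale=5.0,
).images[0]
edited_image.save("edited_mountain.png")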