Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-06 20:44:33 +08:00

Compare commits: variants-f ... instruct-p (15 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 3a95742edc |  |
|  | 6ff90f6c70 |  |
|  | b5e405168d |  |
|  | 8dbc46dfa9 |  |
|  | 466553b885 |  |
|  | 3a5ef6c78f |  |
|  | 2a64edcb2c |  |
|  | ed2a52daf6 |  |
|  | 148d6f9e58 |  |
|  | f35b76c523 |  |
|  | d3eea16750 |  |
|  | 1d486c95a1 |  |
|  | 0fcff42916 |  |
|  | 21fb55844e |  |
|  | c8b88f8b31 |  |
examples/instruct_pix2pix/run.slurm (new file, 124 lines)

@@ -0,0 +1,124 @@
#!/bin/bash
#SBATCH --job-name=instruct-pix2pix-sdxl-emu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
#SBATCH --exclusive
#SBATCH --partition=production-cluster
#SBATCH --output=/admin/home/sayak/logs/instruct-pix2pix-sdxl-emu/%x-%j.out

set -x -e

source /admin/home/sayak/.bashrc
source /admin/home/sayak/miniconda3/etc/profile.d/conda.sh
conda activate diffusers

echo "START TIME: $(date)"

REPO=/fsx/sayak/diffusers/examples/instruct_pix2pix
OUTPUT_DIR=/fsx/sayak/instruct-pix2pix-sdxl-emu
LOG_PATH=$OUTPUT_DIR/main_log.txt
ACCELERATE_CONFIG_FILE="$OUTPUT_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"

mkdir -p $OUTPUT_DIR
touch $LOG_PATH
pushd $REPO


GPUS_PER_NODE=8
NNODES=$SLURM_NNODES
NUM_GPUS=$((GPUS_PER_NODE*SLURM_NNODES))

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000


# Auto-generate the accelerate config
cat << EOT > $ACCELERATE_CONFIG_FILE
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT


export MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_ID="facebook/emu_edit_test_set_generations"

PROGRAM="train_instruct_pix2pix_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_ID \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --dataset_name=$DATASET_ID \
  --original_image_column=image --edited_image_column=edited_image --edit_prompt_column=instruction \
  --resolution=1024 \
  --train_batch_size=8 --gradient_accumulation_steps=4 --gradient_checkpointing \
  --dataloader_num_workers=8 \
  --enable_xformers_memory_efficient_attention \
  --max_train_steps=10000 \
  --checkpointing_steps=2500 \
  --learning_rate=1e-5 --lr_warmup_steps=0 \
  --mixed_precision=fp16 \
  --val_image_url_or_path='https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png' \
  --validation_prompt='Turn sky into a cloudy one' \
  --seed=42 \
  --output_dir=$OUTPUT_DIR \
  --report_to=wandb \
  --push_to_hub
"

# Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
export LAUNCHER="accelerate launch \
  --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,max_restarts=0,tee=3" \
  --config_file $ACCELERATE_CONFIG_FILE \
  --main_process_ip $MASTER_ADDR \
  --main_process_port $MASTER_PORT \
  --num_processes $NUM_GPUS \
  --machine_rank \$SLURM_PROCID \
  "


export CMD="$LAUNCHER $PROGRAM"
echo $CMD

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens


# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
  --wait=60 \
  --kill-on-bad-exit=1 \
  "

clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 2>&1 | tee $LOG_PATH

echo "END TIME: $(date)"
@@ -18,6 +18,7 @@ import argparse
import logging
import math
import os
import random
import shutil
import warnings
from pathlib import Path
@@ -35,11 +36,12 @@ import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

@@ -54,6 +56,10 @@ from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
from diffusers.utils.import_utils import is_xformers_available


if is_wandb_available():
    import wandb


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")

@@ -62,7 +68,7 @@ logger = get_logger(__name__, log_level="INFO")
DATASET_NAME_MAPPING = {
    "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
}
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


@@ -86,6 +92,133 @@ def import_model_class_from_model_name_or_path(
        raise ValueError(f"{model_class} is not supported.")


def tokenize_prompt(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
    prompt_embeds_list = []

    for i, text_encoder in enumerate(text_encoders):
        if tokenizers is not None:
            tokenizer = tokenizers[i]
            text_input_ids = tokenize_prompt(tokenizer, prompt)
        else:
            assert text_input_ids_list is not None
            text_input_ids = text_input_ids_list[i]

        prompt_embeds = text_encoder(
            text_input_ids.to(text_encoder.device),
            output_hidden_states=True,
        )

        # We are only ALWAYS interested in the pooled output of the final text encoder
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds
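For intuition on what the new `encode_prompt` returns, here is a minimal shape sketch with dummy tensors. It assumes the usual SDXL pairing of a 768-dimensional and a 1280-dimensional text encoder, with the pooled vector coming from the final encoder; those dimensions are an assumption of this sketch, not something stated in the diff.

import torch

# Dummy stand-ins for the two encoders' penultimate hidden states (batch of 2 prompts, 77 tokens).
hidden_one = torch.randn(2, 77, 768)    # assumed width of the first text encoder
hidden_two = torch.randn(2, 77, 1280)   # assumed width of the second text encoder
pooled_two = torch.randn(2, 1280)       # pooled output of the final text encoder

# encode_prompt concatenates per-encoder hidden states along the feature dimension.
prompt_embeds = torch.concat([hidden_one, hidden_two], dim=-1)
pooled_prompt_embeds = pooled_two.view(2, -1)

print(prompt_embeds.shape)         # torch.Size([2, 77, 2048])
print(pooled_prompt_embeds.shape)  # torch.Size([2, 1280])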

def log_validation(
    vae,
    unet,
    text_encoder_1,
    text_encoder_2,
    tokenizer_1,
    tokenizer_2,
    args,
    accelerator,
    weight_dtype,
    global_step,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    # The models need unwrapping for compatibility in distributed training mode.
    pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=text_encoder_1,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer_1,
        tokenizer_2=tokenizer_2,
        vae=vae,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    # run inference
    # Save validation images
    val_save_dir = os.path.join(args.output_dir, "validation_images")
    if not os.path.exists(val_save_dir):
        os.makedirs(val_save_dir)

    original_image = (
        lambda image_url_or_path: load_image(image_url_or_path)
        if urlparse(image_url_or_path).scheme
        else Image.open(image_url_or_path).convert("RGB")
    )(args.val_image_url_or_path)
    original_image = original_image.resize((args.resolution, args.resolution))

    with torch.autocast("cuda"):
        edited_images = []
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                height=args.resolution,
                width=args.resolution,
                image=original_image,
                num_inference_steps=25,
                image_guidance_scale=1.5,
                guidance_scale=5.0,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            a_val_img.save(
                os.path.join(
                    val_save_dir,
                    f"step_{global_step}_val_img_{val_img_idx}.png",
                )
            )

    formatted_images = [wandb.Image(original_image, caption="Original Image")]
    for edited_image in edited_images:
        formatted_images.append(wandb.Image(edited_image, caption=args.validation_prompt))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            tracker.log({"validation": formatted_images})

    del pipeline
    torch.cuda.empty_cache()


def parse_args():
    parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.")
    parser.add_argument(
@@ -177,15 +310,7 @@ def parse_args():
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument("--validation_epochs", type=int, default=1, help="Run fine-tuning validation every X epochs.")
    parser.add_argument(
        "--max_train_samples",
        type=int,
@@ -198,7 +323,7 @@
    parser.add_argument(
        "--output_dir",
        type=str,
        default="instruct-pix2pix-model",
        default="instruct-pix2pix-sdxl",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
@@ -216,18 +341,6 @@
            "The resolution for input images, all the images in the train/validation dataset will be resized to this resolution."
        ),
    )
    parser.add_argument(
        "--crops_coords_top_left_h",
        type=int,
        default=0,
        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
    )
    parser.add_argument(
        "--crops_coords_top_left_w",
        type=int,
        default=0,
        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
@@ -443,7 +556,6 @@ def main():
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
@@ -605,6 +717,7 @@ def main():
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
        dataset = concatenate_datasets([dataset["validation"], dataset["test"]])
    else:
        data_files = {}
        if args.train_data_dir is not None:
@@ -619,7 +732,7 @@ def main():

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names
    column_names = dataset.column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
@@ -659,40 +772,6 @@ def main():
        weight_dtype = torch.bfloat16
        warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(captions, tokenizer):
        inputs = tokenizer(
            captions,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        ]
    )

    def preprocess_images(examples):
        original_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
        )
        edited_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
        )
        # We need to ensure that the original and the edited images undergo the same
        # augmentation transforms.
        images = np.concatenate([original_images, edited_images])
        images = torch.tensor(images)
        images = 2 * (images / 255) - 1
        return train_transforms(images)

    # Load scheduler, tokenizer and models.
    tokenizer_1 = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
@@ -729,132 +808,111 @@ def main():
    # Set UNet to trainable.
    unet.train()

    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
    def encode_prompt(text_encoders, tokenizers, prompt):
        prompt_embeds_list = []

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    # Preprocessing the datasets.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[edit_prompt_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{edit_prompt_column}` should contain either strings or lists of strings."
                )
        tokens_one = tokenize_prompt(tokenizer_1, captions)
        tokens_two = tokenize_prompt(tokenizer_2, captions)
        return tokens_one, tokens_two

            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )
    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
    train_flip = transforms.RandomHorizontalFlip(p=1.0)
    normalize = transforms.Normalize([0.5], [0.5])

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)
    def preprocess_train(samples):
        orig_images = [image.convert("RGB") for image in samples[original_image_column]]
        edited_images = [image.convert("RGB") for image in samples[edited_image_column]]
        resized_edited_images = []

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds
        # Resize edited images if necessary.
        for edited_image, orig_image in zip(edited_images, orig_images):
            if edited_image.size != orig_image.size:
                edited_image = edited_image.resize(orig_image.size)
                resized_edited_images.append(edited_image)
            else:
                resized_edited_images.append(edited_image)

    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
    def encode_prompts(text_encoders, tokenizers, prompts):
        prompt_embeds_all = []
        pooled_prompt_embeds_all = []
        # Main image processing.
        final_original_images = []
        final_edited_images = []
        original_sizes = []
        crop_top_lefts = []
        for edited_image, orig_image in zip(resized_edited_images, orig_images):
            original_sizes.append((orig_image.height, orig_image.width))

        for prompt in prompts:
            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
            prompt_embeds_all.append(prompt_embeds)
            pooled_prompt_embeds_all.append(pooled_prompt_embeds)
            images = torch.stack([transforms.ToTensor()(orig_image), transforms.ToTensor()(edited_image)])
            images = train_resize(images)
            if args.center_crop:
                y1 = max(0, int(round((orig_image.height - args.resolution) / 2.0)))
                x1 = max(0, int(round((orig_image.width - args.resolution) / 2.0)))
                images = train_crop(images)
            else:
                y1, x1, h, w = train_crop.get_params(images, (args.resolution, args.resolution))
                images = crop(images, y1, x1, h, w)

        return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)
            if args.random_flip and random.random() < 0.5:
                # flip
                x1 = orig_image.width - x1
                images = train_flip(images)
            crop_top_left = (y1, x1)
            crop_top_lefts.append(crop_top_left)

    # Adapted from examples.dreambooth.train_dreambooth_lora_sdxl
    # Here, we compute not just the text embeddings but also the additional embeddings
    # needed for the SD XL UNet to operate.
    def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
        with torch.no_grad():
            prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)
            add_text_embeds_all = pooled_prompt_embeds_all
            transformed_images = normalize(images)

            prompt_embeds_all = prompt_embeds_all.to(accelerator.device)
            add_text_embeds_all = add_text_embeds_all.to(accelerator.device)
        return prompt_embeds_all, add_text_embeds_all
            # Separate the original and edited images and the edit prompt.
            original_image, edited_image = transformed_images.chunk(2)
            original_image = original_image.squeeze(0)
            edited_image = edited_image.squeeze(0)
            final_original_images.append(original_image)
            final_edited_images.append(edited_image)

    # Get null conditioning
    def compute_null_conditioning():
        null_conditioning_list = []
        for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):
            null_conditioning_list.append(
                a_text_encoder(
                    tokenize_captions([""], tokenizer=a_tokenizer).to(accelerator.device),
                    output_hidden_states=True,
                ).hidden_states[-2]
            )
        return torch.concat(null_conditioning_list, dim=-1)

    null_conditioning = compute_null_conditioning()

    def compute_time_ids():
        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
        original_size = target_size = (args.resolution, args.resolution)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
        return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)

    add_time_ids = compute_time_ids()

    def preprocess_train(examples):
        # Preprocess images.
        preprocessed_images = preprocess_images(examples)
        # Since the original and edited images were concatenated before
        # applying the transformations, we need to separate them and reshape
        # them accordingly.
        original_images, edited_images = preprocessed_images.chunk(2)
        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

        # Collate the preprocessed images into the `examples`.
        examples["original_pixel_values"] = original_images
        examples["edited_pixel_values"] = edited_images

        # Preprocess the captions.
        captions = list(examples[edit_prompt_column])
        prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)
        examples["prompt_embeds"] = prompt_embeds_all
        examples["add_text_embeds"] = add_text_embeds_all
        return examples
        # Pack the values.
        samples["original_sizes"] = original_sizes
        samples["crop_top_lefts"] = crop_top_lefts
        samples["original_pixel_values"] = final_original_images
        samples["edited_pixel_values"] = final_edited_images
        tokens_one, tokens_two = tokenize_captions(samples)
        samples["input_ids_one"] = tokens_one
        samples["input_ids_two"] = tokens_two
        return samples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
            dataset = dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)
        train_dataset = dataset.with_transform(preprocess_train)

    def collate_fn(examples):
        original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()

        edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
        prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0)
        add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0)

        original_sizes = [example["original_sizes"] for example in examples]
        crop_top_lefts = [example["crop_top_lefts"] for example in examples]
        input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
        input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
        return {
            "original_pixel_values": original_pixel_values,
            "edited_pixel_values": edited_pixel_values,
            "prompt_embeds": prompt_embeds,
            "add_text_embeds": add_text_embeds,
            "input_ids_one": input_ids_one,
            "input_ids_two": input_ids_two,
            "original_sizes": original_sizes,
            "crop_top_lefts": crop_top_lefts,
        }

    # DataLoaders creation:
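The hunk stops right at the DataLoader comment. As a rough, self-contained sketch of how a collate_fn like the one above is typically handed to torch's DataLoader: the dummy examples, tensor shapes, and batch size below are made up for illustration, and the real script presumably builds the loader from train_dataset and its own arguments.

import torch
from torch.utils.data import DataLoader

# Tiny stand-in dataset: two fake examples shaped like the output of preprocess_train above.
dummy_examples = [
    {
        "original_pixel_values": torch.randn(3, 64, 64),
        "edited_pixel_values": torch.randn(3, 64, 64),
        "prompt_embeds": torch.randn(1, 77, 2048),
        "add_text_embeds": torch.randn(1, 1280),
        "input_ids_one": torch.randint(0, 49408, (77,)),
        "input_ids_two": torch.randint(0, 49408, (77,)),
        "original_sizes": (64, 64),
        "crop_top_lefts": (0, 0),
    }
    for _ in range(2)
]

def collate_fn(examples):
    # Mirrors the collate_fn in the diff: stack tensors, keep size/crop metadata as plain lists.
    return {
        "original_pixel_values": torch.stack([e["original_pixel_values"] for e in examples]).float(),
        "edited_pixel_values": torch.stack([e["edited_pixel_values"] for e in examples]).float(),
        "prompt_embeds": torch.concat([e["prompt_embeds"] for e in examples], dim=0),
        "add_text_embeds": torch.concat([e["add_text_embeds"] for e in examples], dim=0),
        "input_ids_one": torch.stack([e["input_ids_one"] for e in examples]),
        "input_ids_two": torch.stack([e["input_ids_two"] for e in examples]),
        "original_sizes": [e["original_sizes"] for e in examples],
        "crop_top_lefts": [e["crop_top_lefts"] for e in examples],
    }

train_dataloader = DataLoader(dummy_examples, batch_size=2, shuffle=True, collate_fn=collate_fn)
batch = next(iter(train_dataloader))
print(batch["original_pixel_values"].shape)  # torch.Size([2, 3, 64, 64])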
@@ -947,6 +1005,12 @@ def main():
    else:
        initial_global_step = 0

    # Get null conditioning.
    # Remains fixed throughout training.
    null_conditioning_prompt_embeds, null_conditioning_pooled_prompt_embeds = encode_prompt(
        text_encoders, tokenizers, [""]
    )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
@@ -982,9 +1046,13 @@ def main():
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # SDXL additional inputs
                encoder_hidden_states = batch["prompt_embeds"]
                add_text_embeds = batch["add_text_embeds"]
                # Encode prompts.
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    text_encoders=[text_encoder_1, text_encoder_2],
                    tokenizers=None,
                    prompt=None,
                    text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
                )

                # Get the additional image embedding for conditioning.
                # Instead of getting a diagonal Gaussian here, we simply take the mode.
@@ -992,7 +1060,7 @@ def main():
                    original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
                else:
                    original_pixel_values = batch["original_pixel_values"]
                original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()
                original_image_embeds = vae.encode(original_pixel_values).latent_dist.mode()
                if args.pretrained_vae_model_name_or_path is None:
                    original_image_embeds = original_image_embeds.to(weight_dtype)
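The hunk above switches the original-image conditioning from latent_dist.sample() to latent_dist.mode(). As a rough illustration of the difference, here is a toy diagonal Gaussian in plain torch (not the diffusers distribution object itself, just the same mean/logvar idea):

import torch

# Toy diagonal Gaussian with the mean/logvar parameterization a VAE encoder produces.
mean = torch.zeros(1, 4, 8, 8)
logvar = torch.full((1, 4, 8, 8), -2.0)
std = torch.exp(0.5 * logvar)

mode = mean                                   # .mode(): deterministic, just the mean
sample = mean + std * torch.randn_like(mean)  # .sample(): mean plus noise, stochastic

print(torch.equal(mode, mean))       # True
print(torch.allclose(sample, mean))  # almost surely False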
@@ -1003,8 +1071,13 @@ def main():
                    # Sample masks for the edit prompts.
                    prompt_mask = random_p < 2 * args.conditioning_dropout_prob
                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)
                    pooled_prompt_mask = prompt_mask.reshape(bsz, 1)

                    # Final text conditioning.
                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
                    prompt_embeds = torch.where(prompt_mask, null_conditioning_prompt_embeds, prompt_embeds)
                    pooled_prompt_embeds = torch.where(
                        pooled_prompt_mask, null_conditioning_pooled_prompt_embeds, pooled_prompt_embeds
                    )
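For intuition, a small standalone sketch of this InstructPix2Pix-style conditioning dropout: a random draw per sample decides whether its text conditioning is replaced by the null (empty-prompt) embedding, which is what later makes classifier-free guidance possible at inference. The shapes and the 0.05 probability below are illustrative assumptions, not values from the diff.

import torch

bsz, seq_len, dim = 4, 77, 2048
conditioning_dropout_prob = 0.05  # illustrative value

prompt_embeds = torch.randn(bsz, seq_len, dim)      # per-sample text conditioning
null_prompt_embeds = torch.zeros(1, seq_len, dim)   # stand-in for the empty-prompt embedding

random_p = torch.rand(bsz)
prompt_mask = (random_p < 2 * conditioning_dropout_prob).reshape(bsz, 1, 1)

# Where the mask is True, that sample trains on the null conditioning instead of its prompt.
dropped = torch.where(prompt_mask, null_prompt_embeds, prompt_embeds)
print(prompt_mask.flatten(), dropped.shape)  # e.g. tensor([False, True, False, False]) torch.Size([4, 77, 2048])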

                    # Sample masks for the original images.
                    image_mask_dtype = original_image_embeds.dtype
@@ -1027,11 +1100,24 @@ def main():
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                # Compute additional embedding inputs.
                # time ids
                def compute_time_ids(original_size, crops_coords_top_left):
                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                    target_size = (args.resolution, args.resolution)
                    add_time_ids = list(original_size + crops_coords_top_left + target_size)
                    add_time_ids = torch.tensor([add_time_ids])
                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                    return add_time_ids

                add_time_ids = torch.cat(
                    [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
                )
                unet_added_conditions = {"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}
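As a concrete illustration of the micro-conditioning built here: each sample's add_time_ids is simply original_size + crop_top_left + target_size turned into a six-element tensor. The particular sizes below are made-up numbers.

import torch

original_size = (768, 1024)   # (height, width) recorded in preprocess_train
crop_top_left = (0, 128)      # (y1, x1) recorded during cropping
target_size = (1024, 1024)    # training resolution

add_time_ids = torch.tensor([list(original_size + crop_top_left + target_size)])
print(add_time_ids)        # tensor([[ 768, 1024,    0,  128, 1024, 1024]])
print(add_time_ids.shape)  # torch.Size([1, 6])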

                # Predict the noise residual and compute loss
                model_pred = unet(
                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                    concatenated_noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
                ).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

@@ -1056,8 +1142,8 @@ def main():
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
@@ -1085,81 +1171,37 @@ def main():
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            ### BEGIN: Perform validation every `validation_epochs` steps
            if global_step % args.validation_steps == 0 or global_step == 1:
                if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
                    logger.info(
                        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                        f" {args.validation_prompt}."
                    )

                    # create pipeline
                    if args.use_ema:
                        # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                        ema_unet.store(unet.parameters())
                        ema_unet.copy_to(unet.parameters())

                    # The models need unwrapping for compatibility in distributed training mode.
                    pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        unet=accelerator.unwrap_model(unet),
                        text_encoder=text_encoder_1,
                        text_encoder_2=text_encoder_2,
                        tokenizer=tokenizer_1,
                        tokenizer_2=tokenizer_2,
                        vae=vae,
                        revision=args.revision,
                        torch_dtype=weight_dtype,
                    )
                    pipeline = pipeline.to(accelerator.device)
                    pipeline.set_progress_bar_config(disable=True)

                    # run inference
                    # Save validation images
                    val_save_dir = os.path.join(args.output_dir, "validation_images")
                    if not os.path.exists(val_save_dir):
                        os.makedirs(val_save_dir)

                    original_image = (
                        lambda image_url_or_path: load_image(image_url_or_path)
                        if urlparse(image_url_or_path).scheme
                        else Image.open(image_url_or_path).convert("RGB")
                    )(args.val_image_url_or_path)
                    with torch.autocast(
                        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
                    ):
                        edited_images = []
                        for val_img_idx in range(args.num_validation_images):
                            a_val_img = pipeline(
                                args.validation_prompt,
                                image=original_image,
                                num_inference_steps=20,
                                image_guidance_scale=1.5,
                                guidance_scale=7,
                                generator=generator,
                            ).images[0]
                            edited_images.append(a_val_img)
                            a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

                    for tracker in accelerator.trackers:
                        if tracker.name == "wandb":
                            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                            for edited_image in edited_images:
                                wandb_table.add_data(
                                    wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                                )
                            tracker.log({"validation": wandb_table})
                    if args.use_ema:
                        # Switch back to the original UNet parameters.
                        ema_unet.restore(unet.parameters())

                    del pipeline
                    torch.cuda.empty_cache()
            ### END: Perform validation every `validation_epochs` steps

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if (
                (args.val_image_url_or_path is not None)
                and (args.validation_prompt is not None)
                and (epoch % args.validation_epochs == 0)
            ):
                if args.use_ema:
                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                    ema_unet.store(unet.parameters())
                    ema_unet.copy_to(unet.parameters())

                log_validation(
                    vae=vae,
                    unet=unet,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    args=args,
                    accelerator=accelerator,
                    weight_dtype=weight_dtype,
                    global_step=global_step,
                )

                if args.use_ema:
                    # Switch back to the original UNet parameters.
                    ema_unet.restore(unet.parameters())
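The EMA handling above follows a store, copy-to, restore pattern around validation. Here is a generic torch-only sketch of that idea; it is not the diffusers EMAModel API itself, just a plain state_dict swap standing in for it.

import copy
import torch

net = torch.nn.Linear(4, 4)
ema_shadow = copy.deepcopy(net).state_dict()   # stands in for the EMA weights

# store(): keep the current training weights aside.
backup = copy.deepcopy(net.state_dict())
# copy_to(): load the EMA weights for validation.
net.load_state_dict(ema_shadow)
# ... run validation with the EMA weights here ...
# restore(): put the original training weights back before resuming training.
net.load_state_dict(backup)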
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
@@ -1189,8 +1231,15 @@ def main():

        if args.validation_prompt is not None:
            edited_images = []
            original_image = (
                lambda image_url_or_path: load_image(image_url_or_path)
                if urlparse(image_url_or_path).scheme
                else Image.open(image_url_or_path).convert("RGB")
            )(args.val_image_url_or_path)
            original_image = original_image.resize((args.resolution, args.resolution))

            pipeline = pipeline.to(accelerator.device)
            with torch.autocast(str(accelerator.device).replace(":0", "")):
            with torch.autocast(str(accelerator.device)):
                for _ in range(args.num_validation_images):
                    edited_images.append(
                        pipeline(
@@ -1198,7 +1247,7 @@ def main():
                            image=original_image,
                            num_inference_steps=20,
                            image_guidance_scale=1.5,
                            guidance_scale=7,
                            guidance_scale=5.0,
                            generator=generator,
                        ).images[0]
                    )