mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-16 01:14:47 +08:00
Adding a better way to define multiple concepts and also validation capabilities. (#3807)
* Added validation parameters; changed some parameter descriptions to better explain their use; fixed a few typos; added `concepts_list` parameter for better management of multiple subjects; changed logic for image validation
* Fixed bad logic for class data root directories
* Defaulted `validation_steps` to `None` for simpler logic
* Fixed multiple validation prompts
* Fixed bug on validation negative prompt
* Changed validation logic for the tracker
* Added uuid for validation image labeling
* Fixed error when comparing validation prompts and validation negative prompts
* Improved error message when there are more negative validation prompts than validation prompts
* Changed image tracking number from epoch to global_step; added typing for functions
* Added more validations when using the `concepts_list` parameter alongside the regular ones
* Fixed error message
* Added more validations for validation parameters
* Improved messaging for errors
* Fixed validation error for parameters with default values
* Added train step to image name for validation; reformatted code
* Updated README.md file
* Reverted train_dreambooth.py back to the original script
* Left one blank line at the EOF
* Reverted setup.py back to the original
* Added the same logic for when prior-preservation parameters are used without enabling the flag while using the `concepts_list` parameter
* Ran black formatter
* Fixed a few strings
* Fixed import sort with isort and removed f-strings without placeholders
* Fixed import order with ruff (since isort wasn't ok)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Committed by GitHub
Parent: 2e8668f0af
Commit: 572d8e2002
README.md:

@@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b

Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed good when more intuitive words were used.

**Important**: New parameters have been added to the script, making it possible to validate the progress of the training by generating images at specified steps. Also, because a comma-separated list in a single text field is never a good way to pass multiple prompts (commas very commonly appear inside prompts as regular text), we introduce the `concepts_list` parameter: it lets you specify a JSON file in which you define a separate configuration for each subject you want to train.

An example of how to generate the file:
```python
import json

# here we are using parameters for prior-preservation and validation as well.
concepts_list = [
    {
        "instance_prompt": "drawing of a t@y meme",
        "class_prompt": "drawing of a meme",
        "instance_data_dir": "/some_folder/meme_toy",
        "class_data_dir": "/data/meme",
        "validation_prompt": "drawing of a t@y meme about football in Uruguay",
        "validation_negative_prompt": "black and white"
    },
    {
        "instance_prompt": "drawing of a sks sir",
        "class_prompt": "drawing of a sir",
        "instance_data_dir": "/some_other_folder/sir_sks",
        "class_data_dir": "/data/sir",
        "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
        "validation_negative_prompt": "an old man",
        "validation_guidance_scale": 20,
        "validation_number_images": 3,
        "validation_inference_steps": 10
    }
]

with open("concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)
```
And then just point to the file when executing the script:

```bash
# exports...
accelerate launch train_multi_subject_dreambooth.py \
# more parameters...
--concepts_list="concepts_list.json"
```
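
For reference, a full launch command might look like the following sketch; the model name, directories, and hyperparameter values here are illustrative placeholders, not prescribed settings. Note that `--validation_steps` is passed on the command line, while the per-subject validation settings live in the JSON file:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_multi_subject_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --max_train_steps=1500 \
  --validation_steps=100 \
  --concepts_list="concepts_list.json"
```

Because `concepts_list.json` already defines the instance and class prompts and directories, they are not repeated on the command line; the script raises an error if you pass both.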
You can run the script with `--help` to get a better sense of each parameter.

### Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the above example) in your prompt.
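
For example, a minimal sketch (assuming the trained pipeline was saved to a hypothetical `path-to-save-model` directory):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned weights saved by the training script.
pipeline = StableDiffusionPipeline.from_pretrained("path-to-save-model", torch_dtype=torch.float16).to("cuda")

# Include the learned identifiers (here `t@y` and `sks`) in the prompt.
image = pipeline("drawing of a t@y meme next to a sks sir", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("multi-subject-result.png")
```
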
train_multi_subject_dreambooth.py:

@@ -1,13 +1,18 @@
 import argparse
 import hashlib
 import itertools
+import json
 import logging
 import math
 import os
+import uuid
 import warnings
+from os import environ, listdir, makedirs
+from os.path import basename, join
 from pathlib import Path
+from typing import List
 
 import datasets
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -17,24 +22,140 @@ from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from PIL import Image
+from torch import dtype
+from torch.nn import Module
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.13.0.dev0")
 
 logger = get_logger(__name__)
 
 
+def log_validation_images_to_tracker(
+    images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int
+):
+    logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.")
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+
+# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`
+# argument is implemented.
+def generate_validation_images(
+    text_encoder: Module,
+    tokenizer: Module,
+    unet: Module,
+    vae: Module,
+    arguments: argparse.Namespace,
+    accelerator: Accelerator,
+    weight_dtype: dtype,
+):
+    logger.info("Running validation images.")
+
+    pipeline_args = {}
+
+    if text_encoder is not None:
+        pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+    if vae is not None:
+        pipeline_args["vae"] = vae
+
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        arguments.pretrained_model_name_or_path,
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        revision=arguments.revision,
+        torch_dtype=weight_dtype,
+        **pipeline_args,
+    )
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the
+    # scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    generator = (
+        None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)
+    )
+
+    images_sets = []
+    for vp, nvi, vnp, vis, vgs in zip(
+        arguments.validation_prompt,
+        arguments.validation_number_images,
+        arguments.validation_negative_prompt,
+        arguments.validation_inference_steps,
+        arguments.validation_guidance_scale,
+    ):
+        images = []
+        if vp is not None:
+            logger.info(
+                f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, "
+                f"guidance scale: {vgs}."
+            )
+
+            pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs}
+
+            # run inference
+            # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a
+            # time or in small batches
+            for _ in range(nvi):
+                with torch.autocast("cuda"):
+                    image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]
+                images.append(image)
+
+        images_sets.append(images)
+
+    del pipeline
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return images_sets
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -81,7 +202,7 @@ def parse_args(input_args=None):
         "--instance_data_dir",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="A folder containing the training data of instance images.",
     )
     parser.add_argument(
@@ -95,7 +216,7 @@ def parse_args(input_args=None):
         "--instance_prompt",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="The prompt with identifier specifying the instance",
     )
     parser.add_argument(
@@ -272,6 +393,52 @@ def parse_args(input_args=None):
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=None,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` "
+            "multiple times (`validation_number_images`) and logging the images."
+        ),
+    )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning. You can use commas to "
+        "define multiple prompts. This parameter can also be defined within the file given by the "
+        "`concepts_list` parameter in the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_number_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with the validation parameters given. This "
+        "can be defined within the file given by the `concepts_list` parameter in the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_negative_prompt",
+        type=str,
+        default=None,
+        help="A negative prompt that is used during validation to verify that the model is learning. You can use commas"
+        " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can "
+        "also be defined within the file given by the `concepts_list` parameter in the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_inference_steps",
+        type=int,
+        default=25,
+        help="Number of inference steps (denoising steps) to run during validation. This can be defined within the "
+        "file given by the `concepts_list` parameter in the respective subject.",
+    )
+    parser.add_argument(
+        "--validation_guidance_scale",
+        type=float,
+        default=7.5,
+        help="Controls how much the image generation process follows the text prompt. This can be defined within the "
+        "file given by the `concepts_list` parameter in the respective subject.",
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -297,27 +464,80 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--set_grads_to_none",
+        action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+            " behaviors, so disable this argument if it causes any problems. More info:"
+            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+        ),
+    )
+    parser.add_argument(
+        "--concepts_list",
+        type=str,
+        default=None,
+        help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt,"
+        " class_prompt, etc.",
+    )
 
-    if input_args is not None:
+    if input_args:
         args = parser.parse_args(input_args)
     else:
         args = parser.parse_args()
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):
+        raise ValueError(
+            "You must specify either instance parameters (data directory, prompt, etc.) or use "
+            "the `concepts_list` parameter and specify them within the file."
+        )
+
+    if args.concepts_list:
+        if args.instance_prompt:
+            raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.")
+        if args.instance_data_dir:
+            raise ValueError(
+                "If you are using `concepts_list` parameter, define the instance data directory within the file."
+            )
+        if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):
+            raise ValueError(
+                "If you are using `concepts_list` parameter, define validation parameters for "
+                "each subject within the file:\n - `validation_prompt`."
+                "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`."
+                "\n - `validation_number_images`."
+                "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one "
+                "that needs to be defined outside the file."
+            )
+
+    env_local_rank = int(environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
     if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify a data directory for class images.")
-        if args.class_prompt is None:
-            raise ValueError("You must specify prompt for class images.")
+        if not args.concepts_list:
+            if not args.class_data_dir:
+                raise ValueError("You must specify a data directory for class images.")
+            if not args.class_prompt:
+                raise ValueError("You must specify prompt for class images.")
+        else:
+            if args.class_data_dir:
+                raise ValueError(
+                    "If you are using `concepts_list` parameter, define the class data directory within the file."
+                )
+            if args.class_prompt:
+                raise ValueError(
+                    "If you are using `concepts_list` parameter, define the class prompt within the file."
+                )
     else:
         # logger is not available yet
-        if args.class_data_dir is not None:
-            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
-        if args.class_prompt is not None:
-            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+        if not args.class_data_dir:
+            warnings.warn(
+                "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`."
+            )
+        if not args.class_prompt:
+            warnings.warn(
+                "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`."
+            )
 
     return args
@@ -325,7 +545,7 @@ def parse_args(input_args=None):
 class DreamBoothDataset(Dataset):
     """
     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
+    It pre-processes the images and then tokenizes prompts.
     """
 
     def __init__(
@@ -346,7 +566,7 @@ class DreamBoothDataset(Dataset):
         self.instance_images_path = []
         self.num_instance_images = []
         self.instance_prompt = []
-        self.class_data_root = []
+        self.class_data_root = [] if class_data_root is not None else None
         self.class_images_path = []
         self.num_class_images = []
         self.class_prompt = []
@@ -371,8 +591,6 @@ class DreamBoothDataset(Dataset):
                 self._length -= self.num_instance_images[i]
                 self._length += self.num_class_images[i]
                 self.class_prompt.append(class_prompt[i])
-        else:
-            self.class_data_root = None
 
         self.image_transforms = transforms.Compose(
             [
@@ -446,7 +664,7 @@ def collate_fn(num_instances, examples, with_prior_preservation=False):
 
 
 class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
 
     def __init__(self, prompt, num_samples):
         self.prompt = prompt
@@ -474,6 +692,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -483,23 +705,84 @@ def main(args):
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
-    # Parse instance and class inputs, and double check that lengths match
-    instance_data_dir = args.instance_data_dir.split(",")
-    instance_prompt = args.instance_prompt.split(",")
-    assert all(
-        x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
-    ), "Instance data dir and prompt inputs are not of the same length."
-
-    if args.with_prior_preservation:
-        class_data_dir = args.class_data_dir.split(",")
-        class_prompt = args.class_prompt.split(",")
-        assert all(
-            x == len(instance_data_dir)
-            for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
-        ), "Instance & class data dir or prompt inputs are not of the same length."
-    else:
-        class_data_dir = args.class_data_dir
-        class_prompt = args.class_prompt
+    instance_data_dir = []
+    instance_prompt = []
+    class_data_dir = [] if args.with_prior_preservation else None
+    class_prompt = [] if args.with_prior_preservation else None
+    if args.concepts_list:
+        with open(args.concepts_list, "r") as f:
+            concepts_list = json.load(f)
+
+        if args.validation_steps:
+            args.validation_prompt = []
+            args.validation_number_images = []
+            args.validation_negative_prompt = []
+            args.validation_inference_steps = []
+            args.validation_guidance_scale = []
+
+        for concept in concepts_list:
+            instance_data_dir.append(concept["instance_data_dir"])
+            instance_prompt.append(concept["instance_prompt"])
+
+            if args.with_prior_preservation:
+                try:
+                    class_data_dir.append(concept["class_data_dir"])
+                    class_prompt.append(concept["class_prompt"])
+                except KeyError:
+                    raise KeyError(
+                        "`class_data_dir` or `class_prompt` not found in concepts_list while using "
+                        "`with_prior_preservation`."
+                    )
+            else:
+                if "class_data_dir" in concept:
+                    warnings.warn(
+                        "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`."
+                    )
+                if "class_prompt" in concept:
+                    warnings.warn(
+                        "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`."
+                    )
+
+            if args.validation_steps:
+                args.validation_prompt.append(concept.get("validation_prompt", None))
+                args.validation_number_images.append(concept.get("validation_number_images", 4))
+                args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None))
+                args.validation_inference_steps.append(concept.get("validation_inference_steps", 25))
+                args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5))
+    else:
+        # Parse instance and class inputs, and double check that lengths match
+        instance_data_dir = args.instance_data_dir.split(",")
+        instance_prompt = args.instance_prompt.split(",")
+        assert all(
+            x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+        ), "Instance data dir and prompt inputs are not of the same length."
+
+        if args.with_prior_preservation:
+            class_data_dir = args.class_data_dir.split(",")
+            class_prompt = args.class_prompt.split(",")
+            assert all(
+                x == len(instance_data_dir)
+                for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
+            ), "Instance & class data dir or prompt inputs are not of the same length."
+
+        if args.validation_steps:
+            validation_prompts = args.validation_prompt.split(",")
+            num_of_validation_prompts = len(validation_prompts)
+            args.validation_prompt = validation_prompts
+            args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts
+
+            negative_validation_prompts = [None] * num_of_validation_prompts
+            if args.validation_negative_prompt:
+                negative_validation_prompts = args.validation_negative_prompt.split(",")
+                while len(negative_validation_prompts) < num_of_validation_prompts:
+                    negative_validation_prompts.append(None)
+            args.validation_negative_prompt = negative_validation_prompts
+
+            assert num_of_validation_prompts == len(
+                negative_validation_prompts
+            ), "The length of negative prompts for validation is greater than the number of validation prompts."
+            args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
+            args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -559,21 +842,24 @@ def main(args):
         ):
             images = pipeline(example["prompt"]).images
 
-            for i, image in enumerate(images):
+            for ii, image in enumerate(images):
                 hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                 image_filename = (
-                    class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                    class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
                 )
                 image.save(image_filename)
 
+        # Clean up the memory deleting one-time-use variables.
         del pipeline
+        del sample_dataloader
+        del sample_dataset
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
         if args.output_dir is not None:
-            os.makedirs(args.output_dir, exist_ok=True)
+            makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
             repo_id = create_repo(
@@ -581,6 +867,7 @@ def main(args):
         ).repo_id
 
     # Load the tokenizer
+    tokenizer = None
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
@@ -658,7 +945,7 @@ def main(args):
     train_dataset = DreamBoothDataset(
         instance_data_root=instance_data_dir,
         instance_prompt=instance_prompt,
-        class_data_root=class_data_dir if args.with_prior_preservation else None,
+        class_data_root=class_data_dir,
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=args.resolution,
@@ -720,7 +1007,7 @@ def main(args):
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
+    # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
         accelerator.init_trackers("dreambooth", config=vars(args))
@@ -741,10 +1028,10 @@ def main(args):
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
-            path = os.path.basename(args.resume_from_checkpoint)
+            path = basename(args.resume_from_checkpoint)
         else:
             # Get the most recent checkpoint
-            dirs = os.listdir(args.output_dir)
+            dirs = listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1] if len(dirs) > 0 else None
@@ -756,7 +1043,7 @@ def main(args):
             args.resume_from_checkpoint = None
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
-            accelerator.load_state(os.path.join(args.output_dir, path))
+            accelerator.load_state(join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
             resume_global_step = global_step * args.gradient_accumulation_steps
@@ -787,24 +1074,26 @@ def main(args):
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-                timesteps = timesteps.long()
+                time_steps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                )
+                time_steps = time_steps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    target = noise_scheduler.get_velocity(latents, noise, time_steps)
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -834,19 +1123,34 @@ def main(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
 
-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
-                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
+                        save_path = join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                    if (
+                        args.validation_steps
+                        and any(args.validation_prompt)
+                        and global_step % args.validation_steps == 0
+                    ):
+                        images_set = generate_validation_images(
+                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype
+                        )
+                        for images, validation_prompt in zip(images_set, args.validation_prompt):
+                            if len(images) > 0:
+                                label = str(uuid.uuid1())[:8]  # generate an id for different set of images
+                                log_validation_images_to_tracker(
+                                    images, label, validation_prompt, accelerator, global_step
+                                )
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
@@ -854,7 +1158,7 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break
 
-    # Create the pipeline using using the trained modules and save it.
+    # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         pipeline = DiffusionPipeline.from_pretrained(