Compare commits

...

9 Commits

Author                  SHA1        Message                                                 Date
patrick@huggingface.co  b13dfac9dd  push                                                    2023-07-21 15:26:30 +00:00
patrick@huggingface.co  451631be51  add                                                     2023-07-21 15:19:09 +00:00
Patrick von Platen      71d84a9ce1  fix original sizes                                      2023-07-21 16:21:41 +02:00
patrick@huggingface.co  cfd84dfc14  make it work                                            2023-07-20 21:27:14 +00:00
Patrick von Platen      0592773d90  Correct                                                 2023-07-20 21:03:12 +02:00
Patrick von Platen      c1bad6e488  Correct name                                            2023-07-20 20:55:48 +02:00
Patrick von Platen      b62104c737  improve                                                 2023-07-20 18:46:00 +00:00
Patrick von Platen      697594f635  improve                                                 2023-07-20 18:41:33 +00:00
Patrick von Platen      83d0aba6c0  [ControlNet Webdatasets] Train controlnet webdatasets   2023-07-20 18:17:08 +00:00
3 changed files with 1596 additions and 0 deletions

examples/controlnet/data.py (new file, 235 lines)

@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
import itertools
import math
from typing import List, Optional, Union

import cv2
import numpy as np
import webdataset as wds
from braceexpand import braceexpand
from PIL import Image
from torch.utils.data import default_collate
from torchvision import transforms
from transformers import PreTrainedTokenizer
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)


def filter_keys(key_set):
    def _f(dictionary):
        return {k: v for k, v in dictionary.items() if k in key_set}

    return _f


def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only keep values whose suffix is in this set
    :param handler: exception handler (accepted for API compatibility; unused here)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
        # one begins. Rare, but it can happen since prefixes aren't unique across tar files in that dataset.
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def control_transform(image):
    # Convert the image to a Canny edge map and replicate it to 3 channels,
    # which is the conditioning format expected by a Canny ControlNet.
    image = np.array(image)
    low_threshold = 100
    high_threshold = 200
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    control_image = Image.fromarray(image)
    return control_image


class ImageNetTransform:
    def __init__(self, resolution, center_crop=True, random_flip=False):
        # NOTE: `center_crop` and `random_flip` are accepted but not consulted below;
        # the transforms always center-crop and never flip (a random flip would have to
        # be applied jointly to the image and its control image to keep them aligned).
        self.train_transform = transforms.Compose(
            [
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.train_control_transform = transforms.Compose(
            [
                control_transform,
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        )
        self.eval_transform = transforms.Compose(
            [
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        )


class Text2ImageDataset:
    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        eval_shards_path_or_url: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int,
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        tokenizer_two: Optional[PreTrainedTokenizer] = None,
        resolution: int = 256,
        center_crop: bool = True,
        random_flip: bool = False,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
    ):
        if tokenizer_two is None:
            raise ValueError(
                "`tokenizer_two` must be provided: the processing pipeline always produces"
                " `input_ids_2` (SDXL uses two text encoders)."
            )

        transform = ImageNetTransform(resolution, center_crop, random_flip)

        def tokenize(text):
            input_ids = tokenizer(
                text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            return input_ids[0]

        def tokenize_2(text):
            input_ids = tokenizer_two(
                text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            return input_ids[0]

        if not isinstance(train_shards_path_or_url, str):
            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
            # flatten list using itertools
            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

        if not isinstance(eval_shards_path_or_url, str):
            eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
            # flatten list using itertools
            eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))

        processing_pipeline = [
            wds.decode("pil", handler=wds.ignore_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                control_image="jpg;png;jpeg;webp",
                input_ids="text;txt;caption",
                input_ids_2="text;txt;caption",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys({"image", "control_image", "input_ids", "input_ids_2"})),
            wds.map_dict(
                image=transform.train_transform,
                control_image=transform.train_control_transform,
                input_ids=tokenize,
                input_ids_2=tokenize_2,
            ),
            wds.to_tuple("image", "control_image", "input_ids", "input_ids_2"),
        ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size

        # each worker is iterating over this
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add meta-data to dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader
        pipeline = [
            wds.SimpleShardList(eval_shards_path_or_url),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        return self._eval_dataloader
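
For reference, a minimal usage sketch of the `Text2ImageDataset` above (not part of the diff; the bucket name and shard ranges are placeholders, and the tokenizer checkpoints are assumed to be the SDXL base repo's two tokenizer subfolders). Both tokenizers must be passed because the pipeline always emits `input_ids` and `input_ids_2`:

```python
from transformers import AutoTokenizer

from data import Text2ImageDataset  # the file added in this diff

# Placeholder checkpoints and shard URLs -- substitute your own.
tokenizer_one = AutoTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", subfolder="tokenizer")
tokenizer_two = AutoTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9", subfolder="tokenizer_2")

dataset = Text2ImageDataset(
    train_shards_path_or_url="pipe:aws s3 cp s3://my-bucket/{00000..00009}.tar -",
    eval_shards_path_or_url="pipe:aws s3 cp s3://my-bucket/{00010..00011}.tar -",
    tokenizer=tokenizer_one,
    tokenizer_two=tokenizer_two,
    max_seq_length=77,
    num_train_examples=10_000,
    per_gpu_batch_size=4,
    global_batch_size=4,
    num_workers=2,
    resolution=1024,
)

# Each batch is the tuple produced by wds.to_tuple in the processing pipeline.
image, control_image, input_ids, input_ids_2 = next(iter(dataset.train_dataloader))
print(image.shape, control_image.shape)  # e.g. [4, 3, 1024, 1024] each
```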


@@ -0,0 +1,28 @@
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="controlnet-0-9-canny"
# --max_train_steps=15000
accelerate launch train_controlnet_webdatasets.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=30000 \
--max_train_samples=12000000 \
--dataloader_num_workers=4 \
--validation_image "./c_image_0.png" "./c_image_1.png" "./c_image_2.png" "./c_image_3.png" "./c_image_4.png" "./c_image_5.png" "./c_image_6.png" "./c_image_7.png" \
--validation_prompt "beautiful room" "two paradise birds" "a snowy house behind a forest" "a couple watching a romantic sunset" "boats in the Amazonas" "a beautiful face of a woman" "a skater in Brooklyn" "a tornado in Iowa" \
--train_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{00000..01208}.tar -" \
--eval_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{01209..01210}.tar -" \
--proportion_empty_prompts 0.5 \
--validation_steps=1000 \
--train_batch_size=12 \
--gradient_checkpointing \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention \
--gradient_accumulation_steps=1 \
--seed=42 \
--report_to="wandb" \
--push_to_hub
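
The `--validation_image` files passed above appear to be Canny edge maps (the run is named `controlnet-0-9-canny`). A hedged sketch of how such conditioning images can be produced from source photos, reusing the same `cv2.Canny` thresholds as `control_transform` in `data.py` (the input file name `room.png` is hypothetical):

```python
import cv2
import numpy as np
from PIL import Image

# "room.png" is a placeholder source photo.
image = np.array(Image.open("room.png").convert("RGB"))
edges = cv2.Canny(image, 100, 200)     # same thresholds as control_transform
edges = np.stack([edges] * 3, axis=2)  # replicate the edge map to 3 channels
Image.fromarray(edges).save("c_image_0.png")
```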

File diff suppressed because it is too large (presumably the training script, train_controlnet_webdatasets.py, invoked by the launch command above).