Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: 9 commits, test-clean...controlnet
| Author | SHA1 | Date |
|---|---|---|
|  | b13dfac9dd |  |
|  | 451631be51 |  |
|  | 71d84a9ce1 |  |
|  | cfd84dfc14 |  |
|  | 0592773d90 |  |
|  | c1bad6e488 |  |
|  | b62104c737 |  |
|  | 697594f635 |  |
|  | 83d0aba6c0 |  |
examples/controlnet/data.py (new file, 235 lines)
@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py

import itertools
import json
import math
import random
import re
from typing import List, Optional, Union

import cv2
import numpy as np
import webdataset as wds
from braceexpand import braceexpand
from PIL import Image
from torch.utils.data import default_collate
from torchvision import transforms
from transformers import PreTrainedTokenizer
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)

def filter_keys(key_set):
    # Build a callable that keeps only the whitelisted keys of a sample dict.
    def _f(dictionary):
        return {k: v for k, v in dictionary.items() if k in key_set}

    return _f

def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group key/value pairs from a stream of tar file members into samples.

    :param keys: function that splits a file name into a sample key and an extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (default: True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME: the upstream webdataset implementation throws if `suffix` is already in
        # `current_sample`, but that can happen in the current LAION-400M dataset when one tar
        # ends with the same prefix the next one begins with -- rare, but possible, since
        # prefixes aren't unique across tar files in that dataset.
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample

def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
    # NOTE: a re-implementation of the webdataset pipeline stage, using the
    # non-throwing group_by_keys_nothrow above instead of group_by_keys.
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples

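For reference, the grouping convention above comes from webdataset's `base_plus_ext`, which splits a tar member name into a sample key and an extension. A minimal sketch of how members sharing a key collapse into one sample dict (the file names are hypothetical, not taken from this diff):

```python
from webdataset.tariterators import base_plus_ext

print(base_plus_ext("00000/000001.jpg"))  # ('00000/000001', 'jpg')
print(base_plus_ext("00000/000001.txt"))  # ('00000/000001', 'txt')

# Members sharing the prefix '00000/000001' are merged into a single sample:
# {"__key__": "00000/000001", "jpg": <image bytes>, "txt": <caption bytes>}
```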
def control_transform(image):
    image = np.array(image)

    low_threshold = 100
    high_threshold = 200

    # Extract Canny edges, then replicate the single edge channel to three
    # channels so the conditioning image matches the RGB input layout.
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    control_image = Image.fromarray(image)
    return control_image

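`control_transform` can be sanity-checked on any RGB `PIL.Image` in isolation; a minimal sketch, where `photo.jpg` is a hypothetical placeholder file:

```python
from PIL import Image

img = Image.open("photo.jpg").convert("RGB")  # hypothetical local image
edges = control_transform(img)  # 3-channel PIL image of Canny edges
edges.save("photo_canny.png")
```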
class ImageNetTransform:
    def __init__(self, resolution, center_crop=True, random_flip=False):
        # NOTE: `center_crop` and `random_flip` are accepted but unused here;
        # the transforms below always center-crop and never flip.
        self.train_transform = transforms.Compose(
            [
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.train_control_transform = transforms.Compose(
            [
                control_transform,
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        )
        self.eval_transform = transforms.Compose(
            [
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        )

class Text2ImageDataset:
    def __init__(
        self,
        train_shards_path_or_url: Union[str, List[str]],
        eval_shards_path_or_url: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int,
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers: int,
        tokenizer_two: Optional[PreTrainedTokenizer] = None,
        resolution: int = 256,
        center_crop: bool = True,
        random_flip: bool = False,
        shuffle_buffer_size: int = 1000,
        pin_memory: bool = False,
        persistent_workers: bool = False,
    ):
        transform = ImageNetTransform(resolution, center_crop, random_flip)

        def tokenize(text):
            input_ids = tokenizer(
                text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            return input_ids[0]

        def tokenize_2(text):
            input_ids = tokenizer_two(
                text, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
            return input_ids[0]

        if not isinstance(train_shards_path_or_url, str):
            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
            # flatten the list of lists using itertools
            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

        if not isinstance(eval_shards_path_or_url, str):
            eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
            # flatten the list of lists using itertools
            eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))

        processing_pipeline = [
            wds.decode("pil", handler=wds.ignore_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                control_image="jpg;png;jpeg;webp",
                input_ids="text;txt;caption",
                input_ids_2="text;txt;caption",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys({"image", "control_image", "input_ids", "input_ids_2"})),
            wds.map_dict(
                image=transform.train_transform,
                control_image=transform.train_control_transform,
                input_ids=tokenize,
                input_ids_2=tokenize_2,
            ),
            wds.to_tuple("image", "control_image", "input_ids", "input_ids_2"),
        ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # Round the global batch count up to a multiple of num_workers so every
        # dataloader worker iterates over the same number of batches per epoch,
        # e.g. 12_000_000 examples, global batch 96, 4 workers -> 31_250 batches
        # per worker, 125_000 batches and 12_000_000 samples in total.
        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size

        # each worker iterates over this
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )
        # add metadata to the dataloader instance for convenience
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader
        pipeline = [
            wds.SimpleShardList(eval_shards_path_or_url),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        return self._eval_dataloader
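A minimal sketch of how `Text2ImageDataset` might be driven; the shard URLs and the tokenizer checkpoint are hypothetical placeholders, and a real SDXL run would pass two distinct tokenizers rather than reusing one:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # hypothetical choice
dataset = Text2ImageDataset(
    train_shards_path_or_url="pipe:aws s3 cp s3://my-bucket/shards/{00000..00009}.tar -",
    eval_shards_path_or_url="pipe:aws s3 cp s3://my-bucket/shards/{00010..00011}.tar -",
    tokenizer=tokenizer,
    max_seq_length=77,
    num_train_examples=10_000,
    per_gpu_batch_size=4,
    global_batch_size=4,
    num_workers=2,
    tokenizer_two=tokenizer,  # reusing one tokenizer here purely for brevity
    resolution=512,
)
for image, control_image, input_ids, input_ids_2 in dataset.train_dataloader:
    # image: (B, 3, H, W) normalized tensor; control_image: Canny conditioning images;
    # input_ids / input_ids_2: padded token ids from the two tokenizers
    break
```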
examples/controlnet/run.sh (new file, 28 lines)
@@ -0,0 +1,28 @@
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="controlnet-0-9-canny"

# --max_train_steps=15000 \
accelerate launch train_controlnet_webdatasets.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --resolution=1024 \
  --learning_rate=1e-5 \
  --max_train_steps=30000 \
  --max_train_samples=12000000 \
  --dataloader_num_workers=4 \
  --validation_image "./c_image_0.png" "./c_image_1.png" "./c_image_2.png" "./c_image_3.png" "./c_image_4.png" "./c_image_5.png" "./c_image_6.png" "./c_image_7.png" \
  --validation_prompt "beautiful room" "two paradise birds" "a snowy house behind a forest" "a couple watching a romantic sunset" "boats in the Amazonas" "a beautiful face of a woman" "a skater in Brooklyn" "a tornado in Iowa" \
  --train_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{00000..01208}.tar -" \
  --eval_shards_path_or_url "pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{01209..01210}.tar -" \
  --proportion_empty_prompts 0.5 \
  --validation_steps=1000 \
  --train_batch_size=12 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --enable_xformers_memory_efficient_attention \
  --gradient_accumulation_steps=1 \
  --seed=42 \
  --report_to="wandb" \
  --push_to_hub
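The `{00000..01208}.tar` patterns above are brace-expanded into individual shard URLs; `braceexpand` (the helper `data.py` uses for lists of patterns, and the same expansion webdataset performs on shard specs) can illustrate this with a shortened range:

```python
from braceexpand import braceexpand

urls = list(braceexpand("pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/{00000..00002}.tar -"))
for u in urls:
    print(u)
# pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/00000.tar -
# pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/00001.tar -
# pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-data/00002.tar -
```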
examples/controlnet/train_controlnet_webdatasets.py (new file, 1333 lines)
(File diff suppressed because it is too large.)