@@ -123,16 +123,26 @@ def save_model_card(
 """

     trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    diffusers_imports_pivotal = ""
+    diffusers_example_pivotal = ""
     if train_text_encoder_ti:
         trigger_str = (
             "To trigger image generation of the trained concept (or concepts) replace each concept identifier "
             "in your prompt with the new inserted tokens:\n"
         )
+        diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+"""
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+state_dict = load_file(embedding_path)
+pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+"""
         if token_abstraction_dict:
             for key, value in token_abstraction_dict.items():
                 tokens = "".join(value)
                 trigger_str += f"""
-to trigger concept `{key}->` use `{tokens}` in your prompt \n
+to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """

     yaml = f"""
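For context, a minimal sketch of how the trigger string is assembled when pivotal tuning is on — the `token_abstraction_dict` below is a hypothetical example, not output from the script:

```py
# Hypothetical mapping from a concept identifier to its inserted tokens,
# mirroring what the training loop builds for --token_abstraction.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}

trigger_str = (
    "To trigger image generation of the trained concept (or concepts) replace each "
    "concept identifier in your prompt with the new inserted tokens:\n"
)
for key, value in token_abstraction_dict.items():
    tokens = "".join(value)  # -> "<s0><s1>"
    trigger_str += f"\nto trigger concept `{key}` → use `{tokens}` in your prompt \n\n"

print(trigger_str)
```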
@@ -172,7 +182,21 @@ Special VAE used for training: {vae_path}.

 {trigger_str}

-## Download model
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)

 Weights for this model are available in Safetensors format.
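Once the placeholders are filled in, the generated snippet reads roughly as below for a pivotal-tuning run; the repo id, token names, and prompt are hypothetical stand-ins:

```py
from diffusers import AutoPipelineForText2Image
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("user/my-lora", weight_name="pytorch_lora_weights.safetensors")

# Load the pivotal-tuning embeddings into both SDXL text encoders.
embedding_path = hf_hub_download(
    repo_id="user/my-lora", filename="embeddings.safetensors", repo_type="model"
)
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2,
)

image = pipeline("a photo of <s0><s1>").images[0]
```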
@@ -791,6 +815,12 @@ class DreamBoothDataset(Dataset):
         instance_data_root,
         instance_prompt,
         class_prompt,
+        dataset_name,
+        dataset_config_name,
+        cache_dir,
+        image_column,
+        caption_column,
+        train_text_encoder_ti,
         class_data_root=None,
         class_num=None,
         token_abstraction_dict=None,  # token mapping for textual inversion
@@ -805,10 +835,10 @@ class DreamBoothDataset(Dataset):
         self.custom_instance_prompts = None
         self.class_prompt = class_prompt
         self.token_abstraction_dict = token_abstraction_dict
-
+        self.train_text_encoder_ti = train_text_encoder_ti
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
-        if args.dataset_name is not None:
+        if dataset_name is not None:
             try:
                 from datasets import load_dataset
             except ImportError:
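The same `load_dataset` path also covers the local-folder case mentioned in the comment above; a minimal sketch, assuming a hypothetical directory of images with an optional `metadata.jsonl` of captions:

```py
from datasets import load_dataset

# Hypothetical local training folder; the `imagefolder` builder exposes an
# "image" column and, if metadata.jsonl provides captions, a caption column too.
dataset = load_dataset("imagefolder", data_dir="./dog_photos")
print(dataset["train"].column_names)  # e.g. ['image'] or ['image', 'text']
```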
@@ -821,26 +851,25 @@ class DreamBoothDataset(Dataset):
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
             dataset = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                cache_dir=args.cache_dir,
+                dataset_name,
+                dataset_config_name,
+                cache_dir=cache_dir,
             )
             # Preprocessing the datasets.
             column_names = dataset["train"].column_names

             # 6. Get the column names for input/target.
-            if args.image_column is None:
+            if image_column is None:
                 image_column = column_names[0]
                 logger.info(f"image column defaulting to {image_column}")
             else:
-                image_column = args.image_column
                 if image_column not in column_names:
                     raise ValueError(
-                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
             instance_images = dataset["train"][image_column]

-            if args.caption_column is None:
+            if caption_column is None:
                 logger.info(
                     "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                     "contains captions/prompts for the images, make sure to specify the "
@@ -848,11 +877,11 @@ class DreamBoothDataset(Dataset):
                 )
                 self.custom_instance_prompts = None
             else:
-                if args.caption_column not in column_names:
+                if caption_column not in column_names:
                     raise ValueError(
-                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                        f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                     )
-                custom_instance_prompts = dataset["train"][args.caption_column]
+                custom_instance_prompts = dataset["train"][caption_column]
                 # create final list of captions according to --repeats
                 self.custom_instance_prompts = []
                 for caption in custom_instance_prompts:
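The `--repeats` expansion performed inside this loop keeps each caption aligned with its repeated image; a minimal sketch with hypothetical captions and repeat count:

```py
import itertools

# Hypothetical per-image captions and repeat count.
custom_instance_prompts = ["a photo of TOK dog", "TOK dog at the beach"]
repeats = 2

final_prompts = []
for caption in custom_instance_prompts:
    final_prompts.extend(itertools.repeat(caption, repeats))

print(final_prompts)
# ['a photo of TOK dog', 'a photo of TOK dog',
#  'TOK dog at the beach', 'TOK dog at the beach']
```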
@@ -907,7 +936,7 @@ class DreamBoothDataset(Dataset):
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
             if caption:
-                if args.train_text_encoder_ti:
+                if self.train_text_encoder_ti:
                     # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
                     for token_abs, token_replacement in self.token_abstraction_dict.items():
                         caption = caption.replace(token_abs, "".join(token_replacement))
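The replacement is a plain substring swap; a minimal sketch of what a caption looks like before and after, with a hypothetical mapping:

```py
# Hypothetical token mapping and caption.
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}
caption = "a photo of TOK dog"

for token_abs, token_replacement in token_abstraction_dict.items():
    caption = caption.replace(token_abs, "".join(token_replacement))

print(caption)  # "a photo of <s0><s1> dog"
```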
@@ -1093,10 +1122,10 @@ def main(args):
     if args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

+    model_id = args.hub_model_id or Path(args.output_dir).name
+    repo_id = None
     if args.push_to_hub:
-        repo_id = create_repo(
-            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
-        ).repo_id
+        repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id

     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
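The fallback in `model_id` is just the output directory's final path component; a quick sketch with hypothetical values:

```py
from pathlib import Path

hub_model_id = None  # hypothetical: --hub_model_id not passed
output_dir = "runs/corgi-lora"

model_id = hub_model_id or Path(output_dir).name
print(model_id)  # "corgi-lora"
```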
@@ -1464,6 +1493,12 @@ def main(args):
         instance_data_root=args.instance_data_dir,
         instance_prompt=args.instance_prompt,
         class_prompt=args.class_prompt,
+        dataset_name=args.dataset_name,
+        dataset_config_name=args.dataset_config_name,
+        cache_dir=args.cache_dir,
+        image_column=args.image_column,
+        train_text_encoder_ti=args.train_text_encoder_ti,
+        caption_column=args.caption_column,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
@@ -2004,23 +2039,23 @@ def main(args):
                 }
             )

-        if args.push_to_hub:
-            if args.train_text_encoder_ti:
-                embedding_handler.save_embeddings(
-                    f"{args.output_dir}/embeddings.safetensors",
-                )
-            save_model_card(
-                repo_id,
-                images=images,
-                base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
-                train_text_encoder_ti=args.train_text_encoder_ti,
-                token_abstraction_dict=train_dataset.token_abstraction_dict,
-                instance_prompt=args.instance_prompt,
-                validation_prompt=args.validation_prompt,
-                repo_folder=args.output_dir,
-                vae_path=args.pretrained_vae_model_name_or_path,
-            )
+        if args.train_text_encoder_ti:
+            embedding_handler.save_embeddings(
+                f"{args.output_dir}/embeddings.safetensors",
+            )
+        save_model_card(
+            model_id if not args.push_to_hub else repo_id,
+            images=images,
+            base_model=args.pretrained_model_name_or_path,
+            train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
+            instance_prompt=args.instance_prompt,
+            validation_prompt=args.validation_prompt,
+            repo_folder=args.output_dir,
+            vae_path=args.pretrained_vae_model_name_or_path,
+        )
+        if args.push_to_hub:
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
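A quick way to sanity-check the saved embeddings file after training — the path and shapes below are assumptions based on SDXL's two text encoders, not script output:

```py
from safetensors.torch import load_file

# Hypothetical local path to the file written by embedding_handler.save_embeddings.
state_dict = load_file("output_dir/embeddings.safetensors")

print(state_dict.keys())           # expected: ['clip_l', 'clip_g']
print(state_dict["clip_l"].shape)  # (num_new_tokens, 768)  — text_encoder
print(state_dict["clip_g"].shape)  # (num_new_tokens, 1280) — text_encoder_2
```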