Compare commits

...

2 Commits

Author       SHA1         Message             Date
apolinário   a719c3d67f   style               2023-12-06 14:18:58 +01:00
apolinário   85d93ab09a   add cache latents   2023-12-06 14:12:29 +01:00


@@ -133,7 +133,7 @@ def save_model_card(
         diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
         """
-        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
 pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
@@ -145,8 +145,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], te
 to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """
-    yaml = f"""
----
+    yaml = f"""---
 tags:
 - stable-diffusion-xl
 - stable-diffusion-xl-diffusers
@@ -159,7 +158,7 @@ base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
 ---
 """

     model_card = f"""
 # SDXL LoRA DreamBooth - {repo_id}
@@ -170,14 +169,6 @@ license: openrail++
 ### These are {repo_id} LoRA adaption weights for {base_model}.
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
-LoRA for the text encoder was enabled: {train_text_encoder}.
-Pivotal tuning was enabled: {train_text_encoder_ti}.
-Special VAE used for training: {vae_path}.
 ## Trigger words
 {trigger_str}
@@ -196,11 +187,24 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
-## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
-Weights for this model are available in Safetensors format.
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
+## Download model
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
+- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
+All [Files & versions](/{repo_id}/tree/main).
+## Details
+The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
+LoRA for the text encoder was enabled. {train_text_encoder}.
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+Special VAE used for training: {vae_path}.
 """
 with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -667,6 +671,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
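Because the new `--cache_latents` option is a plain `store_true` argument, it is off by default and enabled simply by appending the flag to the usual launch command for the training script. A minimal standalone sketch of that behavior (illustration only, not part of the diff):

import argparse

parser = argparse.ArgumentParser()
# Same definition as the argument added above.
parser.add_argument("--cache_latents", action="store_true", default=False, help="Cache the VAE latents")

print(parser.parse_args([]).cache_latents)                   # False
print(parser.parse_args(["--cache_latents"]).cache_latents)  # True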
@@ -1170,6 +1180,7 @@ def main(args):
         revision=args.revision,
         variant=args.variant,
     )
+    vae_scaling_factor = vae.config.scaling_factor
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
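This local copy of the scaling factor matters for the hunks below: once latents are cached, the VAE itself may be deleted, so the training loop rescales latents through `vae_scaling_factor` rather than reaching back into `vae.config.scaling_factor`.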
@@ -1600,6 +1611,20 @@ def main(args):
             args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
         print("validation prompt:", args.validation_prompt)

+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=torch.float32
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+        if args.validation_prompt is None:
+            del vae
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
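For readers following the change, here is a small self-contained sketch of the caching pattern this hunk introduces: run the VAE once over the whole dataloader under `torch.no_grad()`, keep only the latent distributions, free the VAE, then sample from the cache during training. The VAE checkpoint id and the random stand-in batches are assumptions for illustration, not part of the change.

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)  # assumed VAE checkpoint
vae_scaling_factor = vae.config.scaling_factor  # stash before the VAE is deleted

# Stand-in for train_dataloader: a few random "image" batches scaled to [-1, 1].
batches = [torch.rand(2, 3, 512, 512, device=device) * 2 - 1 for _ in range(3)]

# Pass 1: encode every batch once and cache only the latent distributions.
latents_cache = []
with torch.no_grad():
    for pixel_values in batches:
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# The VAE forward pass is no longer needed, so it can be freed.
del vae
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Pass 2 (training loop): draw a latent sample from the cached distribution each step.
for step in range(len(batches)):
    model_input = latents_cache[step].sample() * vae_scaling_factor
    print(step, tuple(model_input.shape))

Note that what gets cached is the posterior distribution, not a fixed latent, so `.sample()` still draws fresh noise around the encoded mean on every epoch; only the expensive VAE forward pass is skipped.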
@@ -1715,9 +1740,7 @@ def main(args):
         unet.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-                # print(prompts)
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
@@ -1729,9 +1752,13 @@ def main(args):
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
                         tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)

-                # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+
+                model_input = model_input * vae_scaling_factor
                 if args.pretrained_vae_model_name_or_path is None:
                     model_input = model_input.to(weight_dtype)