Compare commits

...

2 Commits

Author SHA1 Message Date
apolinário
a719c3d67f style 2023-12-06 14:18:58 +01:00
apolinário
85d93ab09a add cache latents 2023-12-06 14:12:29 +01:00

View File

@@ -133,7 +133,7 @@ def save_model_card(
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
@@ -145,8 +145,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], te
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
yaml = f"""
---
yaml = f"""---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
@@ -159,7 +158,7 @@ base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
"""
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id}
@@ -170,14 +169,6 @@ license: openrail++
### These are {repo_id} LoRA adaption weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
## Trigger words
{trigger_str}
@@ -196,11 +187,24 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
## Download model
Weights for this model are available in Safetensors format.
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
[Download]({repo_id}/tree/main) them in the Files & versions tab.
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
All [Files & versions](/{repo_id}/tree/main).
## Details
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
LoRA for the text encoder was enabled. {train_text_encoder}.
Pivotal tuning was enabled: {train_text_encoder_ti}.
Special VAE used for training: {vae_path}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -667,6 +671,12 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1170,6 +1180,7 @@ def main(args):
revision=args.revision,
variant=args.variant,
)
vae_scaling_factor = vae.config.scaling_factor
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
@@ -1600,6 +1611,20 @@ def main(args):
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=torch.float32
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if args.validation_prompt is None:
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1715,9 +1740,7 @@ def main(args):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
# print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1729,9 +1752,13 @@ def main(args):
tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae_scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)