mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-13 03:57:51 +08:00
Compare commits
1 Commits
controlnet
...
fix-widget
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa39fd7cb6 |
14
.github/workflows/delete_doc_comment.yml
vendored
Normal file
14
.github/workflows/delete_doc_comment.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
name: Delete doc comment
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Delete doc comment trigger"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
||||
secrets:
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
||||
12
.github/workflows/delete_doc_comment_trigger.yml
vendored
Normal file
12
.github/workflows/delete_doc_comment_trigger.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
name: Delete doc comment trigger
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [ closed ]
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
||||
@@ -96,8 +96,6 @@ bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
|
||||
|
||||
</div>
|
||||
|
||||
_(We later ran the experiments in float16 and found out that the recent versions of torchao do not incur numerical problems from float16.)_
|
||||
|
||||
**Why bfloat16?**
|
||||
|
||||
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesn’t affect the generation quality but significantly improves latency.
|
||||
@@ -166,9 +164,7 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
`torch.compile` offers different backends and modes. As we’re aiming for maximum inference speed, we opt for the inductor backend using the “max-autotune”. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. Using CUDA graphs greatly reduces the overhead of launching GPU operations. It saves time by using a mechanism to launch multiple GPU operations through a single CPU operation.
|
||||
|
||||
Specifying fullgraph to be True ensures that there are no graph breaks in the underlying model, ensuring the fullest potential of `torch.compile`.
|
||||
`torch.compile` offers different backends and modes. As we’re aiming for maximum inference speed, we opt for the inductor backend using the “max-autotune”. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. Specifying fullgraph to be True ensures that there are no graph breaks in the underlying model, ensuring the fullest potential of `torch.compile`.
|
||||
|
||||
Using SDPA attention and compiling both the UNet and VAE reduces the latency from 3.31 seconds to 2.54 seconds.
|
||||
|
||||
@@ -214,7 +210,7 @@ Through experimentation, we found that certain linear layers in the UNet and the
|
||||
|
||||
</Tip>
|
||||
|
||||
You will leverage the ultra-lightweight pure PyTorch library [torchao](https://github.com/pytorch-labs/ao) (commit SHA: 54bcd5a10d0abbe7b0c045052029257099f83fd9) to use its user-friendly APIs for quantization.
|
||||
You will leverage the ultra-lightweight pure PyTorch library [torchao](https://github.com/pytorch-labs/ao) to use its user-friendly APIs for quantization.
|
||||
|
||||
First, configure all the compiler tags:
|
||||
|
||||
@@ -319,26 +315,4 @@ Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 sec
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
## Misc
|
||||
|
||||
### No graph breaks during torch.compile
|
||||
|
||||
Ensuring that the underlying model/method can be fully compiled is crucial for performance (torch.compile with fullgraph=True). This means having no graph breaks. We did this for the UNet and VAE by changing how we access the returning variables. Consider the following example:
|
||||
|
||||
```diff
|
||||
- latents = unet(
|
||||
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
|
||||
-).sample
|
||||
|
||||
+ latents = unet(
|
||||
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
|
||||
+)[0]
|
||||
```
|
||||
|
||||
### Getting rid of GPU syncs after compilation
|
||||
|
||||
During the iterative reverse diffusion process, we [call](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) `step()` on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). If the `sigmas` array is placed on the GPU, indexing causes a communication sync between the CPU and GPU. This causes a latency, and it becomes more evident when the denoiser has already been compiled.
|
||||
|
||||
But if the `sigmas` array always stays on the CPU (refer to [this line](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240)), this sync doesn’t take place, hence improved latency. In general, any CPU <-> GPU communication sync should be none or be kept to a bare minimum as it can impact inference latency.
|
||||
</div>
|
||||
@@ -318,7 +318,7 @@ make_image_grid([init_image, image], rows=1, cols=2)
|
||||
|
||||
The trade-off of using a non-inpaint specific checkpoint is the overall image quality may be lower, but it generally tends to preserve the mask area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.
|
||||
|
||||
If preserving the unmasked area is important for your task, you can use the `apply_overlay` method of [`VaeImageProcessor`] to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
|
||||
If preserving the unmasked area is important for your task, you can use the code below to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
|
||||
|
||||
```py
|
||||
import PIL
|
||||
@@ -345,7 +345,18 @@ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
repainted_image.save("repainted_image.png")
|
||||
|
||||
unmasked_unchanged_image = pipeline.image_processor.apply_overlay(mask_image, init_image, repainted_image)
|
||||
# Convert mask to grayscale NumPy array
|
||||
mask_image_arr = np.array(mask_image.convert("L"))
|
||||
# Add a channel dimension to the end of the grayscale mask
|
||||
mask_image_arr = mask_image_arr[:, :, None]
|
||||
# Binarize the mask: 1s correspond to the pixels which are repainted
|
||||
mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
|
||||
mask_image_arr[mask_image_arr < 0.5] = 0
|
||||
mask_image_arr[mask_image_arr >= 0.5] = 1
|
||||
|
||||
# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
|
||||
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
|
||||
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
|
||||
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
|
||||
make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
|
||||
```
|
||||
|
||||
@@ -20,7 +20,6 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -38,11 +37,9 @@ from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -57,15 +54,10 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -75,6 +67,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -100,17 +125,10 @@ def save_model_card(
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
embeddings_filename = f"{repo_folder}_emb"
|
||||
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
||||
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
||||
if instance_prompt_webui != embeddings_filename:
|
||||
instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
|
||||
else:
|
||||
instance_prompt_sentence = ""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
diffusers_example_pivotal = ""
|
||||
webui_example_pivotal = ""
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
@@ -119,16 +137,11 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
|
||||
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
|
||||
"""
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
tokens = "".join(value)
|
||||
@@ -160,14 +173,9 @@ license: openrail++
|
||||
|
||||
### These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
## Download model
|
||||
## Trigger words
|
||||
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
|
||||
- Place it on your `models/Lora` folder.
|
||||
- On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
|
||||
{webui_example_pivotal}
|
||||
{trigger_str}
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
@@ -183,12 +191,16 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
|
||||
## Trigger words
|
||||
## Download model
|
||||
|
||||
{trigger_str}
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
|
||||
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
|
||||
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
|
||||
## Details
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
|
||||
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
|
||||
|
||||
@@ -1250,25 +1262,54 @@ def main(args):
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
# Set correct lora layers
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_one, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_two, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||
# in text encoder
|
||||
@@ -1292,17 +1333,6 @@ def main(args):
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1314,15 +1344,11 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1379,12 +1405,6 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||
|
||||
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
|
||||
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
|
||||
|
||||
@@ -1819,17 +1839,9 @@ def main(args):
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# if we're using prior preservation, we calc snr for instance loss only -
|
||||
# and hence only need timesteps corresponding to instance images
|
||||
snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)
|
||||
else:
|
||||
snr_timesteps = timesteps
|
||||
|
||||
snr = compute_snr(noise_scheduler, snr_timesteps)
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
@@ -1983,17 +1995,13 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
@@ -2063,15 +2071,8 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
f"{args.output_dir}/embeddings.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
images=images,
|
||||
|
||||
@@ -12,21 +12,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -34,7 +27,6 @@ from diffusers.models.attention_processor import (
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
@@ -260,7 +252,6 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt_2: str = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights, no length limitation
|
||||
@@ -274,7 +265,6 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt_2 (str)
|
||||
num_images_per_prompt (int)
|
||||
device (torch.device)
|
||||
clip_skip (int)
|
||||
Returns:
|
||||
prompt_embeds (torch.Tensor)
|
||||
neg_prompt_embeds (torch.Tensor)
|
||||
@@ -287,24 +277,17 @@ def get_weighted_text_embeddings_sdxl(
|
||||
if neg_prompt_2:
|
||||
neg_prompt = f"{neg_prompt} {neg_prompt_2}"
|
||||
|
||||
prompt_t1 = prompt_t2 = prompt
|
||||
neg_prompt_t1 = neg_prompt_t2 = neg_prompt
|
||||
|
||||
if isinstance(pipe, TextualInversionLoaderMixin):
|
||||
prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
|
||||
neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
|
||||
prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
|
||||
neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
|
||||
|
||||
eos = pipe.tokenizer.eos_token_id
|
||||
|
||||
# tokenizer 1
|
||||
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
|
||||
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
|
||||
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
||||
|
||||
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
||||
|
||||
# tokenizer 2
|
||||
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
|
||||
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
|
||||
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
|
||||
|
||||
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
|
||||
|
||||
# padding the shorter one for prompt set 1
|
||||
prompt_token_len = len(prompt_tokens)
|
||||
@@ -359,19 +342,13 @@ def get_weighted_text_embeddings_sdxl(
|
||||
|
||||
# use first text encoder
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
pooled_prompt_embeds = prompt_embeds_2[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
else:
|
||||
# "2" because SDXL always indexes from the penultimate layer.
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
|
||||
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
|
||||
|
||||
@@ -544,21 +521,19 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SDXLLongPromptWeightingPipeline(
|
||||
DiffusionPipeline, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
):
|
||||
class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -579,34 +554,12 @@ class SDXLLongPromptWeightingPipeline(
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the encoded image latents.
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
"tokenizer_2",
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
]
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"add_text_embeds",
|
||||
"add_time_ids",
|
||||
"negative_pooled_prompt_embeds",
|
||||
"negative_add_time_ids",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
@@ -616,8 +569,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
@@ -631,8 +582,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
@@ -712,7 +661,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -904,31 +852,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -961,7 +884,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -969,19 +891,14 @@ class SDXLLongPromptWeightingPipeline(
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -1030,95 +947,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
|
||||
# get the original timestep using init_timestep
|
||||
if denoising_start is None:
|
||||
@@ -1413,35 +1241,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
Args:
|
||||
timesteps (`torch.Tensor`):
|
||||
generate embedding vectors at these timesteps
|
||||
embedding_dim (`int`, *optional*, defaults to 512):
|
||||
dimension of the embeddings to generate
|
||||
dtype:
|
||||
data type of the generated embeddings
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -1450,10 +1249,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@@ -1500,22 +1295,19 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -1592,8 +1384,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
@@ -1614,6 +1404,12 @@ class SDXLLongPromptWeightingPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -1637,18 +1433,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1657,23 +1441,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
@@ -1695,12 +1462,10 @@ class SDXLLongPromptWeightingPipeline(
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
@@ -1715,16 +1480,13 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
|
||||
(cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
|
||||
|
||||
negative_prompt = negative_prompt if negative_prompt is not None else ""
|
||||
|
||||
@@ -1734,11 +1496,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = get_weighted_text_embeddings_sdxl(
|
||||
pipe=self,
|
||||
prompt=prompt,
|
||||
neg_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
|
||||
)
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
@@ -1818,7 +1576,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
else:
|
||||
latents, noise = latents
|
||||
|
||||
# 5.1 Prepare mask latent variables
|
||||
# 5.1. Prepare mask latent variables
|
||||
if mask is not None:
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
@@ -1832,7 +1590,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# Check that sizes of mask, masked image and latents match
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
@@ -1853,9 +1611,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 6.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else {}
|
||||
|
||||
height, width = latents.shape[-2:]
|
||||
height = height * self.vae_scale_factor
|
||||
width = width * self.vae_scale_factor
|
||||
@@ -1869,7 +1624,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
@@ -1916,7 +1671,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -1924,7 +1679,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids})
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
@@ -1936,11 +1691,11 @@ class SDXLLongPromptWeightingPipeline(
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
@@ -1963,21 +1718,6 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
@@ -2034,28 +1774,20 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for text-to-image.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -2072,22 +1804,19 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def img2img(
|
||||
@@ -2109,28 +1838,20 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for image-to-image.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -2149,22 +1870,19 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def inpaint(
|
||||
@@ -2188,28 +1906,20 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling pipeline for inpainting.
|
||||
|
||||
Refer to the documentation of the `__call__` method for parameter descriptions.
|
||||
"""
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
@@ -2230,22 +1940,19 @@ class SDXLLongPromptWeightingPipeline(
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
ip_adapter_image=ip_adapter_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
clip_skip=clip_skip,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
if unet is None:
|
||||
raise ValueError("Must provide a `unet` when doing intermediate validation.")
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
state_dict = get_peft_model_state_dict(unet)
|
||||
to_load = state_dict
|
||||
else:
|
||||
to_load = args.output_dir
|
||||
@@ -819,7 +819,7 @@ def main(args):
|
||||
unet_ = accelerator.unwrap_model(unet)
|
||||
# also save the checkpoints in native `diffusers` format so that it can be easily
|
||||
# be independently loaded via `load_lora_weights()`.
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
|
||||
state_dict = get_peft_model_state_dict(unet_)
|
||||
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
|
||||
|
||||
for _, model in enumerate(models):
|
||||
@@ -1184,7 +1184,7 @@ def main(args):
|
||||
# solver timestep.
|
||||
|
||||
# With the adapters disabled, the `unet` is the regular teacher model.
|
||||
accelerator.unwrap_model(unet).disable_adapters()
|
||||
unet.disable_adapters()
|
||||
with torch.no_grad():
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = unet(
|
||||
@@ -1248,7 +1248,7 @@ def main(args):
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
|
||||
|
||||
# re-enable unet adapters to turn the `unet` into a student unet.
|
||||
accelerator.unwrap_model(unet).enable_adapters()
|
||||
unet.enable_adapters()
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
@@ -1332,7 +1332,7 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -86,7 +86,7 @@ accelerate launch train_dreambooth_lora_sdxl.py \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-4 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -853,11 +853,9 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1287,11 +1285,11 @@ def main(args):
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
else:
|
||||
text_encoder_state_dict = None
|
||||
|
||||
|
||||
@@ -494,7 +494,9 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
return self.control_model.attn_processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -507,7 +509,7 @@ class ControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
self.control_model.set_attn_processor(processor)
|
||||
self.control_model.set_attn_processor(processor, _remove_lora)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
# Multi Subject Dreambooth for Inpainting Models
|
||||
|
||||
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
|
||||
|
||||
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
|
||||
|
||||
**The second part**, the `train_multi_subject_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for stable diffusion for inpainting.
|
||||
|
||||
## 1. Data Collection: Make Prompt-Image-Mask Pairs
|
||||
|
||||
Earlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.
|
||||
|
||||
The notebook can be found here: [](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
|
||||
|
||||
The `multi_inpaint_dataset.ipynb` notebook, takes training & validation images, on which the user draws masks and provides prompts to make a prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
|
||||
|
||||

|
||||
|
||||
You can build multiple datasets for every subject and upload them to the 🤗 hub. Later, when launching the training script you can indicate the paths of the datasets, on which you would like to finetune Stable Diffusion for inpaining.
|
||||
|
||||
## 2. Train Multi Subject Dreambooth for Inpainting
|
||||
|
||||
### 2.1. Setting The Training Configuration
|
||||
|
||||
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
export DATASET_1="gzguevara/mr_potato_head_masked"
|
||||
export DATASET_2="gzguevara/cat_toy_masked"
|
||||
... # Further paths to 🤗 datasets
|
||||
```
|
||||
|
||||
### 2.2. Launching The Training Script
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=3e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb
|
||||
```
|
||||
|
||||
### 2.3. Fine-tune text encoder with the UNet.
|
||||
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
|
||||
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
|
||||
|
||||
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=2e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
|
||||
## 3. Results
|
||||
|
||||
A [](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & baises run was performed on a A100 GPU with the following stetting:
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=10 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
Here you can see the target objects on my desk and next to my plant:
|
||||
|
||||

|
||||
@@ -1,8 +0,0 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
datasets>=2.16.0
|
||||
wandb>=0.16.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
@@ -1,661 +0,0 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_from",
|
||||
type=int,
|
||||
default=1000,
|
||||
help=("Start to checkpoint from step"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help=(
|
||||
"Run validation every X steps. Validation consists of running the prompt"
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
||||
" and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_from",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Start to validate from step"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_project_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The w&b name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
datasets_paths,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.datasets_paths = (datasets_paths,)
|
||||
self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]
|
||||
self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
|
||||
self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])
|
||||
|
||||
self.image_normalize = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def set_image(self, img, switch):
|
||||
if img.mode not in ["RGB", "L"]:
|
||||
img = img.convert("RGB")
|
||||
|
||||
if switch:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
img = img.resize((512, 512), Image.BILINEAR)
|
||||
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.train_data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# Lettings
|
||||
example = {}
|
||||
img_idx = index % len(self.train_data)
|
||||
switch = random.choice([True, False])
|
||||
|
||||
# Load image
|
||||
image = self.set_image(self.train_data[img_idx]["image"], switch)
|
||||
|
||||
# Normalize image
|
||||
image_norm = self.image_normalize(image)
|
||||
|
||||
# Tokenise prompt
|
||||
tokenized_prompt = self.tokenizer(
|
||||
self.train_data[img_idx]["prompt"],
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids
|
||||
|
||||
# Load masks for image
|
||||
masks = [
|
||||
self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if "mask" in key
|
||||
]
|
||||
|
||||
# Build example
|
||||
example["PIL_image"] = image
|
||||
example["instance_image"] = image_norm
|
||||
example["instance_prompt_id"] = tokenized_prompt
|
||||
example["instance_masks"] = masks
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def weighted_mask(masks):
|
||||
# Convert each mask to a NumPy array and ensure it's binary
|
||||
mask_arrays = [np.array(mask) / 255 for mask in masks] # Normalizing to 0-1 range
|
||||
|
||||
# Generate random weights and apply them to each mask
|
||||
weights = [random.random() for _ in masks]
|
||||
weights = [weight / sum(weights) for weight in weights]
|
||||
weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]
|
||||
|
||||
# Sum the weighted masks
|
||||
summed_mask = np.sum(weighted_masks, axis=0)
|
||||
|
||||
# Apply a threshold to create the final mask
|
||||
threshold = 0.5 # This threshold can be adjusted
|
||||
result_mask = summed_mask >= threshold
|
||||
|
||||
# Convert the result back to a PIL image
|
||||
return Image.fromarray(result_mask.astype(np.uint8) * 255)
|
||||
|
||||
|
||||
def collate_fn(examples, tokenizer):
|
||||
input_ids = [example["instance_prompt_id"] for example in examples]
|
||||
pixel_values = [example["instance_image"] for example in examples]
|
||||
|
||||
masks, masked_images = [], []
|
||||
|
||||
for example in examples:
|
||||
# generate a random mask
|
||||
mask = weighted_mask(example["instance_masks"])
|
||||
|
||||
# prepare mask and masked image
|
||||
mask, masked_image = prepare_mask_and_masked_image(example["PIL_image"], mask)
|
||||
|
||||
masks.append(mask)
|
||||
masked_images.append(masked_image)
|
||||
|
||||
pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
|
||||
masks = torch.stack(masks)
|
||||
masked_images = torch.stack(masked_images)
|
||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
||||
|
||||
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
|
||||
# update pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
pipeline.unet = accelerator.unwrap_model(unet)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
val_results = [{"data_or_path": pipeline(**pair).images[0], "caption": pair["prompt"]} for pair in val_pairs]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
wandb.log({"validation": [wandb.Image(**val_result) for val_result in val_results]})
|
||||
|
||||
|
||||
def checkpoint(args, global_step, accelerator):
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit,
|
||||
project_dir=args.output_dir,
|
||||
logging_dir=Path(args.output_dir, args.logging_dir),
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
project_config=project_config,
|
||||
log_with="wandb" if args.report_to_wandb else None,
|
||||
)
|
||||
|
||||
if args.report_to_wandb and not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
|
||||
# Load the tokenizer & models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
).requires_grad_(args.train_text_encoder)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
params=itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
train_dataset = DreamBoothDataset(
|
||||
tokenizer=tokenizer,
|
||||
datasets_paths=args.instance_data_dir,
|
||||
)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=lambda examples: collate_fn(examples, tokenizer),
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
else:
|
||||
weight_dtype = torch.float32
|
||||
|
||||
# Move text_encode and vae to gpu.
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
# Afterwards we calculate our number of training epochs
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = vars(copy.deepcopy(args))
|
||||
accelerator.init_trackers(args.validation_project_name, config=tracker_config)
|
||||
|
||||
# create validation pipeline (note: unet and vae are loaded again in float32)
|
||||
val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
torch_dtype=weight_dtype,
|
||||
safety_checker=None,
|
||||
)
|
||||
val_pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# prepare validation dataset
|
||||
val_pairs = [
|
||||
{
|
||||
"image": example["image"],
|
||||
"mask_image": mask,
|
||||
"prompt": example["prompt"],
|
||||
}
|
||||
for example in train_dataset.test_data
|
||||
for mask in [example[key] for key in example if "mask" in key]
|
||||
]
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
for model in models:
|
||||
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
print()
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Convert masked images to latent space
|
||||
masked_latents = vae.encode(
|
||||
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
|
||||
masks = batch["masks"]
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
mask = torch.stack(
|
||||
[
|
||||
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# concatenate the noised latents with the mask and the masked latents
|
||||
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if (
|
||||
global_step % args.validation_steps == 0
|
||||
and global_step >= args.validation_from
|
||||
and args.report_to_wandb
|
||||
):
|
||||
log_validation(
|
||||
val_pipeline,
|
||||
text_encoder,
|
||||
unet,
|
||||
val_pairs,
|
||||
accelerator,
|
||||
)
|
||||
|
||||
if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
|
||||
checkpoint(
|
||||
args,
|
||||
global_step,
|
||||
accelerator,
|
||||
)
|
||||
|
||||
# Step logging
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Terminate training
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -44,7 +44,7 @@ import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -486,9 +486,6 @@ def main():
|
||||
|
||||
lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
@@ -812,9 +809,7 @@ def main():
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(unwrapped_unet)
|
||||
)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
@@ -881,7 +876,7 @@ def main():
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers import (
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -651,15 +651,11 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -706,12 +702,6 @@ def main(args):
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
@@ -1170,14 +1160,14 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
|
||||
# This means that you can input your diffusers-trained LoRAs and
|
||||
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
|
||||
|
||||
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
|
||||
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
|
||||
# and run the script:
|
||||
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
|
||||
# now you can use corgy.safetensors in your WebUI of choice!
|
||||
|
||||
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
|
||||
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
|
||||
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
|
||||
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
|
||||
# Canonical diffusers training scripts:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
|
||||
|
||||
|
||||
def convert_and_save(input_lora, output_lora=None):
|
||||
if output_lora is None:
|
||||
base_name = os.path.splitext(input_lora)[0]
|
||||
output_lora = f"{base_name}_webui.safetensors"
|
||||
|
||||
diffusers_state_dict = load_file(input_lora)
|
||||
peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, output_lora)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
|
||||
parser.add_argument(
|
||||
"input_lora",
|
||||
type=str,
|
||||
help="Path to the input LoRA model file in the diffusers format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_lora",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_and_save(args.input_lora, args.output_lora)
|
||||
@@ -49,9 +49,10 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
env,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)
|
||||
|
||||
self.value_function = value_function
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
self.env = env
|
||||
self.data = env.get_dataset()
|
||||
self.means = {}
|
||||
for key in self.data.keys():
|
||||
|
||||
@@ -634,9 +634,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
init_image_masked = init_image_masked.convert("RGBA")
|
||||
|
||||
if crop_coords is not None:
|
||||
x, y, x2, y2 = crop_coords
|
||||
w = x2 - x
|
||||
h = y2 - y
|
||||
x, y, w, h = crop_coords
|
||||
base_image = PIL.Image.new("RGBA", (width, height))
|
||||
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
||||
base_image.paste(image, (x, y))
|
||||
|
||||
@@ -132,7 +132,7 @@ class IPAdapterMixin:
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
# load CLIP image encoer here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
@@ -141,14 +141,12 @@ class IPAdapterMixin:
|
||||
subfolder=os.path.join(subfolder, "image_encoder"),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
self.feature_extractor = CLIPImageProcessor()
|
||||
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
@@ -159,32 +157,3 @@ class IPAdapterMixin:
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
self.unet.set_default_attn_processor()
|
||||
|
||||
@@ -980,7 +980,7 @@ class LoraLoaderMixin:
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
logger.warning(
|
||||
logger.warn(
|
||||
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
|
||||
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
|
||||
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = IPAdapterFullImageProjection(
|
||||
image_projection = MLPProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
|
||||
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_dims = state_dict["latents"].shape[2]
|
||||
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
image_projection = Resampler(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
|
||||
# because `Resampler` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
|
||||
@@ -498,7 +498,7 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
hidden_states = self.ff_in(hidden_states)
|
||||
|
||||
|
||||
@@ -373,14 +373,29 @@ class Attention(nn.Module):
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor") -> None:
|
||||
def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
|
||||
r"""
|
||||
Set the attention processor to use.
|
||||
|
||||
Args:
|
||||
processor (`AttnProcessor`):
|
||||
The attention processor to use.
|
||||
_remove_lora (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to remove LoRA layers from the model.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
|
||||
deprecate(
|
||||
"set_processor to offload LoRA",
|
||||
"0.26.0",
|
||||
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
|
||||
)
|
||||
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
|
||||
# We need to remove all LoRA layers
|
||||
# Don't forget to remove ALL `_remove_lora` from the codebase
|
||||
for module in self.modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
|
||||
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
||||
# pop `processor` from `self._modules`
|
||||
if (
|
||||
|
||||
@@ -182,7 +182,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -206,9 +208,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -230,7 +232,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
|
||||
@@ -267,7 +267,9 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -291,9 +293,9 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -312,7 +314,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
|
||||
@@ -212,7 +212,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -236,9 +238,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -260,7 +262,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
|
||||
@@ -534,7 +534,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -558,9 +560,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -582,7 +584,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
|
||||
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
|
||||
return image_embeds
|
||||
|
||||
|
||||
class IPAdapterFullImageProjection(nn.Module):
|
||||
class MLPProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
@@ -621,34 +621,29 @@ class AttentionPooling(nn.Module):
|
||||
return a[:, 0, :] # cls_token
|
||||
|
||||
|
||||
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
||||
"""
|
||||
Args:
|
||||
embed_dim: int
|
||||
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
|
||||
Returns:
|
||||
[B x N x embed_dim] tensor of positional embeddings
|
||||
"""
|
||||
class FourierEmbedder(nn.Module):
|
||||
def __init__(self, num_freqs=64, temperature=100):
|
||||
super().__init__()
|
||||
|
||||
batch_size, num_boxes = box.shape[:2]
|
||||
self.num_freqs = num_freqs
|
||||
self.temperature = temperature
|
||||
|
||||
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
|
||||
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
|
||||
emb = emb * box.unsqueeze(-1)
|
||||
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
||||
freq_bands = freq_bands[None, None, None]
|
||||
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
||||
|
||||
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
|
||||
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
||||
|
||||
return emb
|
||||
def __call__(self, x):
|
||||
x = self.freq_bands * x.unsqueeze(-1)
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
|
||||
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
class PositionNet(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
self.out_dim = out_dim
|
||||
|
||||
self.fourier_embedder_dim = fourier_freqs
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
||||
|
||||
if isinstance(out_dim, tuple):
|
||||
@@ -697,7 +692,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
masks = masks.unsqueeze(-1)
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
@@ -792,7 +787,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjection(nn.Module):
|
||||
class Resampler(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -192,7 +192,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -216,9 +218,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -240,7 +242,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -32,10 +32,10 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import (
|
||||
GaussianFourierProjection,
|
||||
GLIGENTextBoundingboxProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
ImageProjection,
|
||||
ImageTimeEmbedding,
|
||||
PositionNet,
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
self.position_net = PositionNet(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
@@ -643,7 +643,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -667,9 +669,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -690,7 +692,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
|
||||
@@ -375,7 +375,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -399,9 +401,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -463,7 +465,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
|
||||
@@ -549,7 +549,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -573,9 +575,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -639,7 +641,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
|
||||
|
||||
@@ -237,7 +237,9 @@ class UVit2DModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -261,9 +263,9 @@ class UVit2DModel(ModelMixin, ConfigMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -285,7 +287,7 @@ class UVit2DModel(ModelMixin, ConfigMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
|
||||
class UVit2DConvEmbed(nn.Module):
|
||||
|
||||
@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
|
||||
@@ -24,7 +24,6 @@ from .controlnet import (
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
@@ -98,7 +97,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky", KandinskyInpaintCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
|
||||
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
class PositionNet(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
self.position_net = PositionNet(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
@@ -848,7 +848,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -872,9 +874,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -895,7 +897,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
|
||||
@@ -530,36 +530,6 @@ def load_sub_model(
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _fetch_class_library_tuple(module):
|
||||
# import it here to avoid circular import
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# register the config from the original module, not the dynamo compiled one
|
||||
not_compiled_module = _unwrap_model(module)
|
||||
library = not_compiled_module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
module_path_items = not_compiled_module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = not_compiled_module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = not_compiled_module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
return (library, class_name)
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
r"""
|
||||
Base class for all pipelines.
|
||||
@@ -586,12 +556,38 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
_is_onnx = False
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
for name, module in kwargs.items():
|
||||
# retrieve library
|
||||
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
|
||||
register_dict = {name: (None, None)}
|
||||
else:
|
||||
library, class_name = _fetch_class_library_tuple(module)
|
||||
# register the config from the original module, not the dynamo compiled one
|
||||
not_compiled_module = _unwrap_model(module)
|
||||
|
||||
library = not_compiled_module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
module_path_items = not_compiled_module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = not_compiled_module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = not_compiled_module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
register_dict = {name: (library, class_name)}
|
||||
|
||||
# save model index config
|
||||
@@ -605,7 +601,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# We need to overwrite the config if name exists in config
|
||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||
if value is not None and self.config[name][0] is not None:
|
||||
class_library_tuple = _fetch_class_library_tuple(value)
|
||||
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
|
||||
else:
|
||||
class_library_tuple = (None, None)
|
||||
|
||||
|
||||
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
)
|
||||
gligen_phrases = gligen_phrases[:max_objs]
|
||||
gligen_boxes = gligen_boxes[:max_objs]
|
||||
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
|
||||
# prepare batched input to the PositionNet (boxes, phrases, mask)
|
||||
# Get tokens for phrases from pre-trained CLIPTokenizer
|
||||
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
|
||||
# For the token, we use the same pre-trained text encoder
|
||||
|
||||
@@ -311,7 +311,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
max_guidance_scale: float = 3.0,
|
||||
fps: int = 7,
|
||||
motion_bucket_id: int = 127,
|
||||
noise_aug_strength: float = 0.02,
|
||||
noise_aug_strength: int = 0.02,
|
||||
decode_chunk_size: Optional[int] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
@@ -346,7 +346,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
||||
motion_bucket_id (`int`, *optional*, defaults to 127):
|
||||
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
|
||||
noise_aug_strength (`float`, *optional*, defaults to 0.02):
|
||||
noise_aug_strength (`int`, *optional*, defaults to 0.02):
|
||||
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
|
||||
decode_chunk_size (`int`, *optional*):
|
||||
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
|
||||
|
||||
@@ -91,7 +91,9 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -115,9 +117,9 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -139,7 +141,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@@ -277,11 +277,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
|
||||
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
||||
else:
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
|
||||
@@ -98,9 +98,7 @@ from .peft_utils import (
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import (
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
)
|
||||
|
||||
@@ -16,11 +16,6 @@ State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
import enum
|
||||
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
@@ -28,7 +23,7 @@ class StateDictType(enum.Enum):
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
KOHYA_SS = "kohya_ss"
|
||||
# KOHYA_SS = "kohya_ss" # TODO: implement this
|
||||
PEFT = "peft"
|
||||
DIFFUSERS = "diffusers"
|
||||
|
||||
@@ -105,14 +100,6 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
PEFT_TO_KOHYA_SS = {
|
||||
"lora_A": "lora_down",
|
||||
"lora_B": "lora_up",
|
||||
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
|
||||
# adding prefixes and adding alpha values
|
||||
# Check `convert_state_dict_to_kohya` for more
|
||||
}
|
||||
|
||||
PEFT_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
|
||||
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
|
||||
@@ -123,8 +110,6 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
|
||||
|
||||
KEYS_TO_ALWAYS_REPLACE = {
|
||||
".processor.": ".",
|
||||
}
|
||||
@@ -243,82 +228,3 @@ def convert_unet_state_dict_to_peft(state_dict):
|
||||
"""
|
||||
mapping = UNET_TO_DIFFUSERS
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
|
||||
|
||||
def convert_all_state_dict_to_peft(state_dict):
|
||||
r"""
|
||||
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
|
||||
for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
|
||||
"""
|
||||
try:
|
||||
peft_dict = convert_state_dict_to_peft(state_dict)
|
||||
except Exception as e:
|
||||
if str(e) == "Could not automatically infer state dict type":
|
||||
peft_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
else:
|
||||
raise
|
||||
|
||||
if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()):
|
||||
raise ValueError("Your LoRA was not converted to PEFT")
|
||||
|
||||
return peft_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
|
||||
The method only supports the conversion from PEFT to Kohya for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
|
||||
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
|
||||
`get_peft_model_state_dict` method:
|
||||
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
|
||||
but we add it here in case we don't want to rely on that method.
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
|
||||
raise
|
||||
|
||||
peft_adapter_name = kwargs.pop("adapter_name", None)
|
||||
if peft_adapter_name is not None:
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
else:
|
||||
peft_adapter_name = ""
|
||||
|
||||
if original_type is None:
|
||||
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
|
||||
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
# Use the convert_state_dict function with the appropriate mapping
|
||||
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
|
||||
kohya_ss_state_dict = {}
|
||||
|
||||
# Additional logic for replacing header, alpha parameters `.` with `_` in all keys
|
||||
for kohya_key, weight in kohya_ss_partial_state_dict.items():
|
||||
if "text_encoder_2." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
|
||||
elif "text_encoder." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
|
||||
elif "unet" in kohya_key:
|
||||
kohya_key = kohya_key.replace("unet", "lora_unet")
|
||||
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
|
||||
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
|
||||
kohya_ss_state_dict[kohya_key] = weight
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
|
||||
|
||||
return kohya_ss_state_dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -22,6 +22,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from packaging import version
|
||||
@@ -40,6 +41,8 @@ from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
@@ -75,6 +78,28 @@ def state_dicts_almost_equal(sd1, sd2):
|
||||
return models_are_equal
|
||||
|
||||
|
||||
def create_unet_lora_layers(unet: nn.Module):
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
|
||||
return lora_attn_procs, unet_lora_layers
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class PeftLoraLoaderMixinTests:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -115,6 +140,8 @@ class PeftLoraLoaderMixinTests:
|
||||
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
pipeline_components = {
|
||||
"unet": unet,
|
||||
@@ -138,8 +165,11 @@ class PeftLoraLoaderMixinTests:
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
|
||||
return pipeline_components, text_lora_config, unet_lora_config
|
||||
lora_components = {
|
||||
"unet_lora_layers": unet_lora_layers,
|
||||
"unet_lora_attn_procs": unet_lora_attn_procs,
|
||||
}
|
||||
return pipeline_components, lora_components, text_lora_config, unet_lora_config
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
@@ -186,7 +216,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple inference and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -201,7 +231,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -232,7 +262,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -279,7 +309,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -321,7 +351,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -364,7 +394,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA.
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -429,7 +459,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -480,7 +510,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -553,7 +583,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -607,7 +637,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected - with unet
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -653,7 +683,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -700,7 +730,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -750,7 +780,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -818,7 +848,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set/delete them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -908,7 +938,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -980,7 +1010,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1018,7 +1048,7 @@ class PeftLoraLoaderMixinTests:
|
||||
are the expected results
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1045,7 +1075,7 @@ class PeftLoraLoaderMixinTests:
|
||||
are the expected results
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1083,7 +1113,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected - with unet and multi-adapter case
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1145,7 +1175,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class PEFTLoRALoading(unittest.TestCase):
|
||||
def get_dummy_inputs(self):
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"generator": torch.manual_seed(0),
|
||||
}
|
||||
return pipeline_inputs
|
||||
|
||||
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
# For details on how the LoRA was obtained, refer to:
|
||||
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
@@ -26,7 +26,7 @@ from pytest import mark
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.models.embeddings import ImageProjection, Resampler
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
image_projection = Resampler(
|
||||
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
|
||||
)
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ from diffusers import (
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -229,25 +228,6 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_unload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
|
||||
pipeline.unload_ip_adapter()
|
||||
|
||||
assert getattr(pipeline, "image_encoder") is None
|
||||
assert getattr(pipeline, "feature_extractor") is None
|
||||
processors = [
|
||||
isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0))
|
||||
for name, attn_proc in pipeline.unet.attn_processors.items()
|
||||
]
|
||||
assert processors == [True] * len(processors)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user