mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-11 15:04:45 +08:00

Compare commits: v0.19.1-pa ... v0.19.3-pa

8 Commits

| SHA1 |
|---|
| 4308bc5dbb |
| de9c72d58c |
| 7b022df49c |
| 965e52ce61 |
| b1e52794a2 |
| c3e3a1ee10 |
| 9cde56a729 |
| c63d7cdba0 |
@@ -21,7 +21,7 @@ For example, to perform Image-to-Image with the SD1.5 checkpoint, you can do
```python
from diffusers import PipelineForImageToImage

-pipe_i2i = PipelineForImageoImage.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe_i2i = PipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
```

It will also help you switch between tasks seamlessly using the same checkpoint without reallocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do
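The snippet that follows that sentence falls outside this hunk. As a minimal sketch of what the task switch could look like; `PipelineForInpainting` and its `from_pipe` helper are assumptions that mirror the `PipelineForImageToImage` API above, not names confirmed by this diff:

```python
from diffusers import PipelineForInpainting  # hypothetical counterpart to PipelineForImageToImage

# Re-use the components already loaded by pipe_i2i instead of allocating new memory.
pipe_inpaint = PipelineForInpainting.from_pipe(pipe_i2i)  # `from_pipe` is an assumed helper
```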
@@ -38,9 +38,25 @@ You can install the libraries as follows:
pip install transformers
pip install accelerate
pip install safetensors
```

+### Watermarker
+
+We recommend adding an invisible watermark to images generated by Stable Diffusion XL; this can help identify whether an image is machine-synthesised in downstream applications. To do so, please install
+the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:
+
+```
+pip install "invisible-watermark>=0.2.0"
+```
+
+If the `invisible-watermark` library is installed, the watermarker will be used **by default**.
+
+If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:
+
+```py
+pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
+```

### Text-to-Image

You can use SDXL as follows for *text-to-image*:
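The text-to-image snippet itself also falls outside the hunk; a minimal sketch with the public `StableDiffusionXLPipeline` API, where the model id and prompt are illustrative:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# The prompt is illustrative; any text prompt works here.
image = pipe(prompt="An astronaut riding a green horse").images[0]
image.save("astronaut.png")
```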
@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```

+### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
+
+With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
+
+Here are some example checkpoints we tried out:
+
+* SDXL 0.9:
+  * https://civitai.com/models/22279?modelVersionId=118556
+  * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
+  * https://civitai.com/models/108448/daiton-sdxl-test
+  * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
+* SDXL 1.0:
+  * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
+
+Here is an example of how to perform inference with these checkpoints in `diffusers`:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
+pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
+
+prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
+negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
+generator = torch.manual_seed(2947883060)
+num_inference_steps = 30
+guidance_scale = 7
+
+image = pipeline(
+    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
+    generator=generator, guidance_scale=guidance_scale
+).images[0]
+image.save("Kamepan.png")
+```
+
+`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 .
+
+If you look carefully, the inference UX is exactly the same as what we presented in the sections above.
+
+Thanks to [@isidentical](https://github.com/isidentical) for helping us integrate this feature.
+
+### Known limitations specific to the Kohya-styled LoRAs
+
+* SDXL LoRAs that contain LoRA layers for both text encoders currently lead to weird results. We're actively investigating the issue.
+* When images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
@@ -4,6 +4,5 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
-invisible-watermark>=0.2.0
datasets
wandb
@@ -4,4 +4,3 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
-invisible-watermark>=0.2.0
@@ -924,10 +924,10 @@ def main(args):
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
@@ -829,13 +829,13 @@ def main(args):
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
        )
        LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
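The `network_alpha` to `network_alphas` rename in the two training-script hunks above is the core API change of this patch release; a hedged sketch of the difference in shape, with purely illustrative key names:

```python
# Before (v0.19.1): a single scalar alpha applied to every LoRA layer.
network_alpha = 8.0

# After (v0.19.3): one alpha per LoRA module, keyed by module name.
# The key names below are illustrative, not taken from a real checkpoint.
network_alphas = {
    "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.alpha": 8.0,
    "text_encoder.text_model.encoder.layers.0.self_attn.alpha": 4.0,
}
```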
setup.py (2 changes)
@@ -233,7 +233,7 @@ install_requires = [
setup(
    name="diffusers",
-    version="0.19.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.19.3",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    description="Diffusers",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
@@ -1,4 +1,4 @@
-__version__ = "0.19.1"
+__version__ = "0.19.3"

from .configuration_utils import ConfigMixin
from .utils import (
@@ -185,6 +185,11 @@ else:
        StableDiffusionPix2PixZeroPipeline,
        StableDiffusionSAGPipeline,
        StableDiffusionUpscalePipeline,
+        StableDiffusionXLControlNetPipeline,
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLInpaintPipeline,
+        StableDiffusionXLInstructPix2PixPipeline,
+        StableDiffusionXLPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
        TextToVideoSDPipeline,
@@ -202,20 +207,6 @@ else:
        VQDiffusionPipeline,
    )

-try:
-    if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
-else:
-    from .pipelines import (
-        StableDiffusionXLControlNetPipeline,
-        StableDiffusionXLImg2ImgPipeline,
-        StableDiffusionXLInpaintPipeline,
-        StableDiffusionXLInstructPix2PixPipeline,
-        StableDiffusionXLPipeline,
-    )
-
try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
@@ -56,7 +57,6 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
-TOTAL_EXAMPLE_KEYS = 5

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin:
        use_safetensors = kwargs.pop("use_safetensors", None)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        network_alpha = kwargs.pop("network_alpha", None)
+        network_alphas = kwargs.pop("network_alphas", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
@@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin:
        attn_processors = {}
        non_attn_lora_layers = []

-        is_lora = all("lora" in k for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

        if is_lora:
@@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin:
            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}

            lora_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
+            mapped_network_alphas = {}
+
+            all_keys = list(state_dict.keys())
+            for key in all_keys:
+                value = state_dict.pop(key)
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

+                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+                if network_alphas is not None:
+                    for k in network_alphas:
+                        if k.replace(".alpha", "") in key:
+                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+
+            if len(state_dict) > 0:
+                raise ValueError(
+                    f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                )

            for key, value_dict in lora_grouped_dict.items():
                attn_processor = self
                for sub_key in key.split("."):
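A worked example of the key grouping above, with a hypothetical state-dict key, may help:

```python
# Illustrative: how load_attn_procs splits a flat LoRA key (key name is made up).
key = "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
attn_processor_key = ".".join(key.split(".")[:-3])  # "...attn1.processor"
sub_key = ".".join(key.split(".")[-3:])             # "to_q_lora.down.weight"
```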
@@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin:
                # or add_{k,v,q,out_proj}_proj_lora layers.
                if "lora.down.weight" in value_dict:
                    rank = value_dict["lora.down.weight"].shape[0]
-                    hidden_size = value_dict["lora.up.weight"].shape[0]

                    if isinstance(attn_processor, LoRACompatibleConv):
-                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                        in_features = attn_processor.in_channels
+                        out_features = attn_processor.out_channels
+                        kernel_size = attn_processor.kernel_size
+
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
                    elif isinstance(attn_processor, LoRACompatibleLinear):
                        lora = LoRALinearLayer(
-                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
                        )
                    else:
                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
@@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin:
                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
                    lora.load_state_dict(value_dict)
                    non_attn_lora_layers.append((attn_processor, lora))
                    continue

-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
-                if isinstance(
-                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
-                ):
-                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
-                    attn_processor_class = LoRAAttnAddedKVProcessor
-                else:
-                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
-                        attn_processor_class = LoRAXFormersAttnProcessor
-                    else:
-                        attn_processor_class = (
-                            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-                        )
+                # To handle SDXL.
+                rank_mapping = {}
+                hidden_size_mapping = {}
+                for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
+                    rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
+                    hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
+
+                    rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
+                    hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
+
+                if isinstance(
+                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
+                ):
+                    cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
+                    attn_processor_class = LoRAAttnAddedKVProcessor
+                else:
+                    cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+                    if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
+                        attn_processor_class = LoRAXFormersAttnProcessor
+                    else:
+                        attn_processor_class = (
+                            LoRAAttnProcessor2_0
+                            if hasattr(F, "scaled_dot_product_attention")
+                            else LoRAAttnProcessor
+                        )
+
+                if attn_processor_class is not LoRAAttnAddedKVProcessor:
+                    attn_processors[key] = attn_processor_class(
+                        rank=rank_mapping.get("to_k_lora.down.weight"),
+                        hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
+                        cross_attention_dim=cross_attention_dim,
+                        network_alpha=mapped_network_alphas.get(key),
+                        q_rank=rank_mapping.get("to_q_lora.down.weight"),
+                        q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
+                        v_rank=rank_mapping.get("to_v_lora.down.weight"),
+                        v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
+                        out_rank=rank_mapping.get("to_out_lora.down.weight"),
+                        out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
+                        # rank=rank_mapping.get("to_k_lora.down.weight", None),
+                        # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
+                        # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
+                        # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
+                        # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
+                        # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
+                        # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
+                        # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
+                    )
+                else:
+                    attn_processors[key] = attn_processor_class(
+                        rank=rank_mapping.get("to_k_lora.down.weight", None),
+                        hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
+                        cross_attention_dim=cross_attention_dim,
+                        network_alpha=mapped_network_alphas.get(key),
+                    )

-                attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    rank=rank,
-                    network_alpha=network_alpha,
-                )
-                attn_processors[key].load_state_dict(value_dict)
+                attn_processors[key].load_state_dict(value_dict)

        elif is_custom_diffusion:
            custom_diffusion_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
@@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin:
        # set ff layers
        for target_module, lora_layer in non_attn_lora_layers:
-            if hasattr(target_module, "set_lora_layer"):
-                target_module.set_lora_layer(lora_layer)
+            target_module.set_lora_layer(lora_layer)
+            # It should raise an error if we don't have a set lora here
+            # if hasattr(target_module, "set_lora_layer"):
+            #     target_module.set_lora_layer(lora_layer)

    def save_attn_procs(
        self,
@@ -880,11 +943,11 @@ class LoraLoaderMixin:
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
        self.load_lora_into_text_encoder(
            state_dict,
-            network_alpha=network_alpha,
+            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
        )
@@ -896,7 +959,7 @@ class LoraLoaderMixin:
        **kwargs,
    ):
        r"""
-        Return state dict for lora weights
+        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>
@@ -957,6 +1020,7 @@ class LoraLoaderMixin:
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
+        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
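The new `unet_config` kwarg feeds the SGM-block remapping introduced below; a minimal, hedged sketch of how a caller could pass it, where the repo id and weight name are illustrative:

```python
# Hypothetical usage of the new `unet_config` kwarg; `pipe` is assumed
# to be an already-loaded SDXL pipeline.
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
    "some-user/some-sdxl-kohya-lora",               # illustrative repo id
    weight_name="pytorch_lora_weights.safetensors",
    unet_config=pipe.unet.config,
)
```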
@@ -1018,53 +1082,158 @@ class LoraLoaderMixin:
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
-        network_alpha = None
-        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+        network_alphas = None
+        if all(
+            (
+                k.startswith("lora_te_")
+                or k.startswith("lora_unet_")
+                or k.startswith("lora_te1_")
+                or k.startswith("lora_te2_")
+            )
+            for k in state_dict.keys()
+        ):
+            # Map SDXL blocks correctly.
+            if unet_config is not None:
+                # use unet config to remap block numbers
+                state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+            state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

-        return state_dict, network_alpha
+        return state_dict, network_alphas

    @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+        new_state_dict = {}
+        inner_block_map = ["resnets", "attentions", "upsamplers"]
+
+        # Retrieves # of down, mid and up blocks
+        input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
+        for layer in state_dict:
+            if "text" not in layer:
+                layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
+                if "input_blocks" in layer:
+                    input_block_ids.add(layer_id)
+                elif "middle_block" in layer:
+                    middle_block_ids.add(layer_id)
+                elif "output_blocks" in layer:
+                    output_block_ids.add(layer_id)
+                else:
+                    raise ValueError("Checkpoint not supported")
+
+        input_blocks = {
+            layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
+            for layer_id in input_block_ids
+        }
+        middle_blocks = {
+            layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
+            for layer_id in middle_block_ids
+        }
+        output_blocks = {
+            layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
+            for layer_id in output_block_ids
+        }
+
+        # Rename keys accordingly
+        for i in input_block_ids:
+            block_id = (i - 1) // (unet_config.layers_per_block + 1)
+            layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
+
+            for key in input_blocks[i]:
+                inner_block_id = int(key.split(delimiter)[block_slice_pos])
+                inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
+                inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1]
+                    + [str(block_id), inner_block_key, inner_layers_in_block]
+                    + key.split(delimiter)[block_slice_pos + 1 :]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        for i in middle_block_ids:
+            key_part = None
+            if i == 0:
+                key_part = [inner_block_map[0], "0"]
+            elif i == 1:
+                key_part = [inner_block_map[1], "0"]
+            elif i == 2:
+                key_part = [inner_block_map[0], "1"]
+            else:
+                raise ValueError(f"Invalid middle block id {i}.")
+
+            for key in middle_blocks[i]:
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        for i in output_block_ids:
+            block_id = i // (unet_config.layers_per_block + 1)
+            layer_in_block_id = i % (unet_config.layers_per_block + 1)
+
+            for key in output_blocks[i]:
+                inner_block_id = int(key.split(delimiter)[block_slice_pos])
+                inner_block_key = inner_block_map[inner_block_id]
+                inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
+                new_key = delimiter.join(
+                    key.split(delimiter)[: block_slice_pos - 1]
+                    + [str(block_id), inner_block_key, inner_layers_in_block]
+                    + key.split(delimiter)[block_slice_pos + 1 :]
+                )
+                new_state_dict[new_key] = state_dict.pop(key)
+
+        if is_all_unet and len(state_dict) > 0:
+            raise ValueError("At this point all state dict entries have to be converted.")
+        else:
+            # Remaining is the text encoder state dict.
+            for k, v in state_dict.items():
+                new_state_dict.update({k: v})
+
+        return new_state_dict
+
+    @classmethod
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
        """
-        This will load the LoRA layers specified in `state_dict` into `unet`
+        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
-            network_alpha (`float`):
+            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
        """

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())

        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
            logger.info(f"Loading {cls.unet_name}.")
-            unet_lora_state_dict = {
-                k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
-            }
-            unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
-
-        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-        # contain the module names of the `unet` as its keys WITHOUT any prefix.
-        elif not all(
-            key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
-        ):
-            unet.load_attn_procs(state_dict, network_alpha=network_alpha)
+
+            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
+                network_alphas = {
+                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
+        else:
+            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
+            # contain the module names of the `unet` as its keys WITHOUT any prefix.
+            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
+            warnings.warn(warn_message)
+
+        # load loras into unet
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas)

    @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
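To make the renaming arithmetic in `_map_sgm_blocks_to_diffusers` concrete, here is a small worked example, assuming the SD/SDXL default of `layers_per_block=2`; the key is illustrative:

```python
# Key "lora_unet_input_blocks_4_1_..." split by "_" puts the block id at
# index 4 and the inner-block id at index 5 (block_slice_pos).
i = 4                                  # SGM input-block id
block_id = (i - 1) // (2 + 1)          # -> 1
layer_in_block_id = (i - 1) % (2 + 1)  # -> 0
# split("_")[5] == "1" and inner_block_map[1] == "attentions", so the key becomes
# "lora_unet_input_blocks_1_attentions_0_...", which _convert_kohya_lora_to_diffusers
# later rewrites to the diffusers name "down_blocks.1.attentions.0...".
```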
@@ -1072,7 +1241,7 @@ class LoraLoaderMixin:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
-            network_alpha (`float`):
+            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
@@ -1141,14 +1310,19 @@ class LoraLoaderMixin:
            ].shape[1]
            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
+            cls._modify_text_encoder(
+                text_encoder,
+                lora_scale,
+                network_alphas,
+                rank=rank,
+                patch_mlp=patch_mlp,
+            )

            # set correct dtype & device
            text_encoder_lora_state_dict = {
                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                for k, v in text_encoder_lora_state_dict.items()
            }

            load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
            if len(load_state_dict_results.unexpected_keys) != 0:
                raise ValueError(
@@ -1183,7 +1357,7 @@ class LoraLoaderMixin:
        cls,
        text_encoder,
        lora_scale=1,
-        network_alpha=None,
+        network_alphas=None,
        rank=4,
        dtype=None,
        patch_mlp=False,
@@ -1196,37 +1370,46 @@ class LoraLoaderMixin:
        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)

        lora_parameters = []
+        network_alphas = {} if network_alphas is None else network_alphas

-        for _, attn_module in text_encoder_attn_modules(text_encoder):
+        for name, attn_module in text_encoder_attn_modules(text_encoder):
+            query_alpha = network_alphas.get(name + ".k.proj.alpha")
+            key_alpha = network_alphas.get(name + ".q.proj.alpha")
+            value_alpha = network_alphas.get(name + ".v.proj.alpha")
+            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+
            attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

            attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

            attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

            attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

        if patch_mlp:
-            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
+                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
+                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+
                mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

                mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
                )
                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
@@ -1333,77 +1516,163 @@ class LoraLoaderMixin:
    def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
-        network_alpha = None
-        unloaded_keys = []
+        te2_state_dict = {}
+        network_alphas = {}

-        for key, value in state_dict.items():
-            if "hada" in key or "skip" in key:
-                unloaded_keys.append(key)
-            elif "lora_down" in key:
-                lora_name = key.split(".")[0]
-                lora_name_up = lora_name + ".lora_up.weight"
-                lora_name_alpha = lora_name + ".alpha"
-                if lora_name_alpha in state_dict:
-                    alpha = state_dict[lora_name_alpha].item()
-                    if network_alpha is None:
-                        network_alpha = alpha
-                    elif network_alpha != alpha:
-                        raise ValueError("Network alpha is not consistent")
+        # every down weight has a corresponding up weight and potentially an alpha weight
+        lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
+        for key in lora_keys:
+            lora_name = key.split(".")[0]
+            lora_name_up = lora_name + ".lora_up.weight"
+            lora_name_alpha = lora_name + ".alpha"

-            if lora_name.startswith("lora_unet_"):
-                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-                diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-                diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-                diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-                diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-                diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-                diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-                diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-                if "transformer_blocks" in diffusers_name:
-                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                        unet_state_dict[diffusers_name] = value
-                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                elif "ff" in diffusers_name:
-                    unet_state_dict[diffusers_name] = value
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                    unet_state_dict[diffusers_name] = value
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-
-            elif lora_name.startswith("lora_te_"):
-                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
-                diffusers_name = diffusers_name.replace("text.model", "text_model")
-                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-                if "self_attn" in diffusers_name:
-                    te_state_dict[diffusers_name] = value
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-                elif "mlp" in diffusers_name:
-                    # Be aware that this is the new diffusers convention and the rest of the code might
-                    # not utilize it yet.
-                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                    te_state_dict[diffusers_name] = value
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
-
-        logger.info("Kohya-style checkpoint detected.")
-        if len(unloaded_keys) > 0:
-            example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
-            logger.warning(
-                f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
-            )
+            # if lora_name_alpha in state_dict:
+            #     alpha = state_dict.pop(lora_name_alpha).item()
+            #     network_alphas.update({lora_name_alpha: alpha})

+            if lora_name.startswith("lora_unet_"):
+                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+                if "input.blocks" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+                else:
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+
+                if "middle.block" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+                else:
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                if "output.blocks" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+                else:
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+
+                diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+                diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+                # SDXL specificity.
+                if "emb" in diffusers_name:
+                    pattern = r"\.\d+(?=\D*$)"
+                    diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+                if ".in." in diffusers_name:
+                    diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+                if ".out." in diffusers_name:
+                    diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+                if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("op", "conv")
+                if "skip" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+                if "transformer_blocks" in diffusers_name:
+                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                        unet_state_dict[diffusers_name] = state_dict.pop(key)
+                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "ff" in diffusers_name:
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    unet_state_dict[diffusers_name] = state_dict.pop(key)
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            elif lora_name.startswith("lora_te_"):
+                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # (sayakpaul): Duplicate code. Needs to be cleaned.
+            elif lora_name.startswith("lora_te1_"):
+                diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # (sayakpaul): Duplicate code. Needs to be cleaned.
+            elif lora_name.startswith("lora_te2_"):
+                diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+                diffusers_name = diffusers_name.replace("text.model", "text_model")
+                diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                if "self_attn" in diffusers_name:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            # Rename the alphas so that they can be mapped appropriately.
+            if lora_name_alpha in state_dict:
+                alpha = state_dict.pop(lora_name_alpha).item()
+                if lora_name_alpha.startswith("lora_unet_"):
+                    prefix = "unet."
+                elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+                    prefix = "text_encoder."
+                else:
+                    prefix = "text_encoder_2."
+                new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+                network_alphas.update({new_name: alpha})
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
+            )

-        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
-        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        logger.info("Kohya-style checkpoint detected.")
+        unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {
+            f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
+        }
+        te2_state_dict = (
+            {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+            if len(te2_state_dict) > 0
+            else None
+        )
+        if te2_state_dict is not None:
+            te_state_dict.update(te2_state_dict)

        new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha
+        return new_state_dict, network_alphas

    def unload_lora_weights(self):
        """
@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
        Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
    """

    def __init__(
-        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+        self,
+        hidden_size,
+        cross_attention_dim,
+        rank=4,
+        attention_op: Optional[Callable] = None,
+        network_alpha=None,
+        **kwargs,
    ):
        super().__init__()
@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
        self.rank = rank
        self.attention_op = attention_op

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
        Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
+
+        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        residual = hidden_states
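The `**kwargs` added to these processors exist so SDXL Kohya LoRAs can carry different ranks and hidden sizes per projection; a hedged instantiation sketch, with sizes invented purely for illustration:

```python
# All sizes below are made up for the example.
proc = LoRAAttnProcessor2_0(
    hidden_size=1280,
    cross_attention_dim=2048,
    rank=4,                 # falls back to this rank for to_k
    network_alpha=4.0,
    q_rank=8, q_hidden_size=1280,
    v_rank=8, v_hidden_size=1280,
    out_rank=8, out_hidden_size=1280,
)
```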
@@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module):
class LoRAConv2dLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
+    def __init__(
+        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+    ):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

-        self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
-        self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
+        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+        # according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
+        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
+        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
@@ -23,6 +23,7 @@ import torch.nn.functional as F
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
+from .lora import LoRACompatibleConv, LoRACompatibleLinear


class Upsample1D(nn.Module):
@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
-            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
        self.name = name

        if use_conv:
-            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
-                self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
+                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)
@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
        self.conv_shortcut = None
        if self.use_in_shortcut:
-            self.conv_shortcut = torch.nn.Conv2d(
+            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )
@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
-from .lora import LoRACompatibleConv
+from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
+                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
|
||||
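These hunks swap plain `nn.Conv2d`/`nn.Linear` layers for `LoRACompatibleConv`/`LoRACompatibleLinear` throughout the ResNet and transformer blocks so that LoRA weights can later be attached to them. As a rough mental model (a minimal sketch, not the actual `diffusers` implementation — the class name and the `scale` handling below are illustrative assumptions):

```python
import torch
import torch.nn as nn


class LoRACompatibleConvSketch(nn.Conv2d):
    """A conv layer that can carry an optional low-rank (LoRA) delta."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = None  # assigned later, when LoRA weights are loaded

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        out = super().forward(x)
        if self.lora_layer is not None:
            # low-rank update: down-project then up-project, scaled by `scale`
            out = out + scale * self.lora_layer(x)
        return out
```

With no LoRA loaded the layer behaves exactly like the `nn.Conv2d` it replaces, which is why the swap is safe for existing checkpoints.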
@@ -1,7 +1,6 @@
 from ..utils import (
     OptionalDependencyNotAvailable,
     is_flax_available,
-    is_invisible_watermark_available,
     is_k_diffusion_available,
     is_librosa_available,
     is_note_seq_available,
@@ -51,6 +50,7 @@ else:
         StableDiffusionControlNetImg2ImgPipeline,
         StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionXLControlNetPipeline,
     )
     from .deepfloyd_if import (
         IFImg2ImgPipeline,
@@ -108,6 +108,12 @@ else:
         StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
+    from .stable_diffusion_xl import (
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLInpaintPipeline,
+        StableDiffusionXLInstructPix2PixPipeline,
+        StableDiffusionXLPipeline,
+    )
     from .t2i_adapter import StableDiffusionAdapterPipeline
     from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
@@ -121,20 +127,6 @@ else:
     from .vq_diffusion import VQDiffusionPipeline
 
 
-try:
-    if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
-else:
-    from .controlnet import StableDiffusionXLControlNetPipeline
-    from .stable_diffusion_xl import (
-        StableDiffusionXLImg2ImgPipeline,
-        StableDiffusionXLInpaintPipeline,
-        StableDiffusionXLInstructPix2PixPipeline,
-        StableDiffusionXLPipeline,
-    )
-
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
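The net effect of this `__init__.py` change is that the SDXL pipelines are exported under the ordinary torch + transformers guard, so `invisible-watermark` is no longer required just to import them. A quick way to check (a sketch; whether the watermarker is active still depends on your local install):

```python
# With only torch and transformers installed, these imports now succeed;
# the watermarker is simply disabled when invisible-watermark is absent.
from diffusers import StableDiffusionXLControlNetPipeline, StableDiffusionXLPipeline

print(StableDiffusionXLPipeline)  # no OptionalDependencyNotAvailable error
```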
@@ -1,21 +1,11 @@
 from ...utils import (
     OptionalDependencyNotAvailable,
     is_flax_available,
-    is_invisible_watermark_available,
     is_torch_available,
     is_transformers_available,
 )
 
 
-try:
-    if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
-else:
-    from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
-
-
 try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
@@ -26,6 +16,7 @@ else:
     from .pipeline_controlnet import StableDiffusionControlNetPipeline
     from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
     from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+    from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
 
 
 if is_transformers_available() and is_flax_available():
@@ -22,6 +22,8 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
@@ -42,7 +44,11 @@ from ...utils import (
 )
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
-from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+if is_invisible_watermark_available():
+    from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
 from .multicontrolnet import MultiControlNetModel
 
 
@@ -109,6 +115,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         controlnet: ControlNetModel,
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -130,7 +137,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.watermark = StableDiffusionXLWatermarker()
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -995,7 +1008,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             image = latents
             return StableDiffusionXLPipelineOutput(images=image)
 
-        image = self.watermark.apply_watermark(image)
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
     _optional_components = []
     _exclude_from_cpu_offload = []
     _load_connected_pipes = False
+    _is_onnx = False
 
     def register_modules(self, **kwargs):
         # import it here to avoid circular import
@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                 weights. If set to `False`, safetensors weights are not loaded.
+            use_onnx (`bool`, *optional*, defaults to `None`):
+                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
+                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
+                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
+                with `.onnx` and `.pb`.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+                weights. If set to `False`, safetensors weights are not loaded.
+            use_onnx (`bool`, *optional*, defaults to `False`):
+                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
+                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
+                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
+                with `.onnx` and `.pb`.
 
         Returns:
             `os.PathLike`:
@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
         custom_revision = kwargs.pop("custom_revision", None)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
         if use_safetensors and not is_safetensors_available():
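In short, `download()` and `from_pretrained()` gain a `use_onnx` flag, and each pipeline class advertises whether it is ONNX-based via `_is_onnx`. For example (a sketch using the tiny test checkpoint from this PR's own test suite; any repo that ships `.onnx`/`.pb` files works the same way):

```python
from diffusers import DiffusionPipeline

# Default: ONNX weights (*.onnx, *.pb) are skipped for regular PyTorch pipelines.
snapshot = DiffusionPipeline.download("hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline")

# Opt in explicitly to fetch the ONNX weights as well.
snapshot_with_onnx = DiffusionPipeline.download(
    "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
    use_onnx=True,
)
```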
@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
                 pretrained_model_name, use_auth_token, variant, revision, model_filenames
             )
 
-        model_folder_names = {os.path.split(f)[0] for f in model_filenames}
+        model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
 
         # all filenames compatible with variant will be added
         allow_patterns = list(model_filenames)
@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
         ):
             ignore_patterns = ["*.bin", "*.msgpack"]
 
+            use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
+            if not use_onnx:
+                ignore_patterns += ["*.onnx", "*.pb"]
+
             safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
             safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
             if (
@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
         else:
             ignore_patterns = ["*.safetensors", "*.msgpack"]
 
+            use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
+            if not use_onnx:
+                ignore_patterns += ["*.onnx", "*.pb"]
+
             bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
             bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
             if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
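The `model_folder_names` fix also tightens which sub-folders count toward the variant consistency check, so repos with extra non-component folders are handled more robustly. Fetching a specific weight variant remains a one-liner (a sketch; whether an `fp16` variant exists depends on the repository):

```python
from diffusers import DiffusionPipeline

# Downloads model.fp16.safetensors-style files where the repo provides them.
path = DiffusionPipeline.download("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16")
```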
@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     feature_extractor: CLIPImageProcessor
 
     _optional_components = ["safety_checker", "feature_extractor"]
+    _is_onnx = True
 
     def __init__(
         self,
@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     feature_extractor: CLIPImageProcessor
 
     _optional_components = ["safety_checker", "feature_extractor"]
+    _is_onnx = True
 
     def __init__(
         self,
@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     feature_extractor: CLIPImageProcessor
 
     _optional_components = ["safety_checker", "feature_extractor"]
+    _is_onnx = True
 
     def __init__(
         self,
@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     _optional_components = ["safety_checker", "feature_extractor"]
+    _is_onnx = True
 
     vae_encoder: OnnxRuntimeModel
     vae_decoder: OnnxRuntimeModel
@@ -46,6 +46,8 @@ def preprocess(image):
 
 
 class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
+    _is_onnx = True
+
     def __init__(
         self,
         vae: OnnxRuntimeModel,
@@ -7,7 +7,6 @@ import PIL
 from ...utils import (
     BaseOutput,
     OptionalDependencyNotAvailable,
-    is_invisible_watermark_available,
     is_torch_available,
     is_transformers_available,
 )
@@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
 
 
 try:
-    if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
+    if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
+    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
     from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
@@ -32,13 +32,17 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_accelerate_available,
     is_accelerate_version,
+    is_invisible_watermark_available,
     logging,
     randn_tensor,
     replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from .watermark import StableDiffusionXLWatermarker
+
+
+if is_invisible_watermark_available():
+    from .watermark import StableDiffusionXLWatermarker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -84,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
 
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
-        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
 
     as well as the following saving methods:
-        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+        - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -125,6 +129,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -142,7 +147,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.default_sample_size = self.unet.config.sample_size
 
-        self.watermark = StableDiffusionXLWatermarker()
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -839,7 +849,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
             image = latents
             return StableDiffusionXLPipelineOutput(images=image)
 
-        image = self.watermark.apply_watermark(image)
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
@@ -853,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
 
     # Override to properly handle the loading and unloading of the additional text encoder.
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
-                network_alpha=network_alpha,
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
@@ -870,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         if len(text_encoder_2_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_2_state_dict,
-                network_alpha=network_alpha,
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
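The rename from `network_alpha` to `network_alphas` reflects that Kohya-style checkpoints can carry a *per-module* alpha rather than one global value. Alpha only rescales the low-rank update — conceptually (a sketch of the standard LoRA scaling rule, not of `diffusers` internals):

```python
# Effective update applied to a layer with LoRA matrices A (down) and B (up):
#   W_eff = W + (alpha / rank) * (B @ A)
# Each module can ship its own alpha, hence a dict of alphas instead of one float.
def lora_scale_factor(alpha: float, rank: int) -> float:
    return alpha / rank
```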
@@ -33,13 +33,17 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     is_accelerate_available,
     is_accelerate_version,
+    is_invisible_watermark_available,
     logging,
     randn_tensor,
     replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from .watermark import StableDiffusionXLWatermarker
+
+
+if is_invisible_watermark_available():
+    from .watermark import StableDiffusionXLWatermarker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -131,6 +135,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -148,7 +153,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
-        self.watermark = StableDiffusionXLWatermarker()
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -906,15 +916,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 9. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
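This repeat-before-concat fix matters because the negative and positive micro-conditioning ids can differ (aesthetic score vs. negative aesthetic score), so each must be tiled to the batch size *before* they are stacked for classifier-free guidance. A tiny shape walk-through (illustrative numbers only):

```python
import torch

batch_size, num_images_per_prompt = 2, 1
add_time_ids = torch.zeros(1, 5)      # positive ids: one row of 5 conditioning values
add_neg_time_ids = torch.ones(1, 5)   # negative ids differ, so they need their own tiling

add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)          # (2, 5)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)  # (2, 5)

# CFG stacks all negative rows first, matching the prompt-embedding concat order.
stacked = torch.cat([add_neg_time_ids, add_time_ids], dim=0)  # (4, 5)
```

Repeating only once after the concat (the old code path) would have tiled the already-stacked tensor and interleaved the negative/positive rows instead of keeping them in two blocks.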
@@ -988,7 +1000,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
             image = latents
             return StableDiffusionXLPipelineOutput(images=image)
 
-        image = self.watermark.apply_watermark(image)
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
@@ -30,10 +30,20 @@ from ...models.attention_processor import (
     XFormersAttnProcessor,
 )
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ...utils import (
+    is_accelerate_available,
+    is_accelerate_version,
+    is_invisible_watermark_available,
+    logging,
+    randn_tensor,
+    replace_example_docstring,
+)
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from .watermark import StableDiffusionXLWatermarker
+
+
+if is_invisible_watermark_available():
+    from .watermark import StableDiffusionXLWatermarker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -265,6 +275,7 @@ class StableDiffusionXLInpaintPipeline(
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -282,7 +293,12 @@ class StableDiffusionXLInpaintPipeline(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
-        self.watermark = StableDiffusionXLWatermarker()
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -1168,15 +1184,17 @@ class StableDiffusionXLInpaintPipeline(
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 11. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1264,6 +1282,10 @@ class StableDiffusionXLInpaintPipeline(
         else:
             return StableDiffusionXLPipelineOutput(images=latents)
 
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
@@ -34,12 +34,16 @@ from ...utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_invisible_watermark_available,
     logging,
     randn_tensor,
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from .watermark import StableDiffusionXLWatermarker
+
+
+if is_invisible_watermark_available():
+    from .watermark import StableDiffusionXLWatermarker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -109,6 +113,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()
 
@@ -128,7 +133,12 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
 
         self.vae.config.force_upcast = True  # force the VAE to be in float32 mode, as it overflows in float16
 
-        self.watermark = StableDiffusionXLWatermarker()
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
 
     def enable_vae_slicing(self):
         r"""
@@ -811,6 +821,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         original_prompt_embeds_len = len(prompt_embeds)
         original_add_text_embeds_len = len(add_text_embeds)
@@ -819,6 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
             add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
 
         # Make dimensions consistent
@@ -828,7 +840,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
 
         prompt_embeds = prompt_embeds.to(device).to(torch.float32)
         add_text_embeds = add_text_embeds.to(device).to(torch.float32)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 11. Denoising loop
         self.unet = self.unet.to(torch.float32)
@@ -906,7 +918,10 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
             image = latents
             return StableDiffusionXLPipelineOutput(images=image)
 
-        image = self.watermark.apply_watermark(image)
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
         image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
@@ -1,6 +1,11 @@
 import numpy as np
 import torch
-from imwatermark import WatermarkEncoder
+
+from ...utils import is_invisible_watermark_available
+
+
+if is_invisible_watermark_available():
+    from imwatermark import WatermarkEncoder
 
 
 # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
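For context, the watermarker wraps the DWT-DCT encoder from `invisible-watermark`. A standalone sketch of the underlying API (assuming `opencv-python` and `invisible-watermark` are installed; the bit pattern below is arbitrary, not the payload diffusers embeds):

```python
import numpy as np
from imwatermark import WatermarkEncoder

bgr = np.zeros((256, 256, 3), dtype=np.uint8)  # stand-in for a generated image (BGR)

encoder = WatermarkEncoder()
encoder.set_watermark("bits", [0, 1] * 24)     # 48-bit payload
bgr_marked = encoder.encode(bgr, "dwtDct")     # imperceptible frequency-domain watermark
```

Because the import now happens lazily behind `is_invisible_watermark_available()`, merely importing the module no longer hard-requires the library.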
@@ -1,77 +0,0 @@
-# This file is autogenerated by the command `make fix-copies`, do not edit.
-from ..utils import DummyObject, requires_backends
-
-
-class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers", "invisible_watermark"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-
-class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers", "invisible_watermark"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-
-class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers", "invisible_watermark"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-
-class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers", "invisible_watermark"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-
-class StableDiffusionXLPipeline(metaclass=DummyObject):
-    _backends = ["torch", "transformers", "invisible_watermark"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@@ -827,6 +827,81 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
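These dummy classes are what users hit when a backend is missing: constructing or loading the pipeline raises an informative `ImportError` instead of a bare `NameError`. Moving the SDXL dummies here means the error now lists only `torch` and `transformers`. The mechanism, reduced to its essentials (a simplified sketch — the real `DummyObject`/`requires_backends` in `diffusers.utils` are more general):

```python
def requires_backends(obj, backends):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass: instantiating the class triggers the backend check."""

    def __call__(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
```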
@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
         ).images
 
         images = images[0, -3:, -3:, -1].flatten()
 
-        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
+        expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64 if not skip_first_text_encoder else 32,
         )
         scheduler = EulerDiscreteScheduler(
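The switch from 80 to 72 tracks the `requires_aesthetics_score=True` setting added to these test components: with an aesthetic score, SDXL's refiner-style conditioning uses 5 micro-conditioning ids (original size, crop coordinates, aesthetic score) instead of 6 (original size, crop coordinates, target size). The projection width follows directly (a worked check of the comments in the diff):

```python
addition_time_embed_dim = 8   # Fourier embedding width per conditioning id
pooled_text_dim = 32          # pooled text-encoder projection in these tiny test models

assert 6 * addition_time_embed_dim + pooled_text_dim == 80  # without aesthetics score
assert 5 * addition_time_embed_dim + pooled_text_dim == 72  # with aesthetics score
```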
@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "requires_aesthetics_score": True,
         }
         return components
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5
@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
+        expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
+        expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64 if not skip_first_text_encoder else 32,
         )
         scheduler = EulerDiscreteScheduler(
@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "requires_aesthetics_score": True,
         }
         return components
 
@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
         }
         return inputs
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def test_stable_diffusion_xl_inpaint_euler(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
+        expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
 
-        print(torch.from_numpy(image_slice).flatten())
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
+        expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64,
         )
 
@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
-            # "safety_checker": None,
-            # "feature_extractor": None,
+            "requires_aesthetics_score": True,
         }
         return components
 
@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
         }
         return inputs
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
             assert len([f for f in files if ".bin" in f]) == 8
             assert not any(".safetensors" in f for f in files)
 
+    def test_download_no_openvino_by_default(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = DiffusionPipeline.download(
+                "hf-internal-testing/tiny-stable-diffusion-open-vino",
+                cache_dir=tmpdirname,
+            )
+
+            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # make sure that by default no openvino weights are downloaded
+            assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+            assert not any("openvino_" in f for f in files)
+
+    def test_download_no_onnx_by_default(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = DiffusionPipeline.download(
+                "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
+                cache_dir=tmpdirname,
+            )
+
+            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # make sure that by default no onnx weights are downloaded
+            assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+            assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = DiffusionPipeline.download(
+                "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
+                cache_dir=tmpdirname,
+                use_onnx=True,
+            )
+
+            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # if `use_onnx` is specified make sure weights are downloaded
+            assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+            assert any((f.endswith(".onnx")) for f in files)
+            assert any((f.endswith(".pb")) for f in files)
+
     def test_download_no_safety_checker(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained(