Compare commits

...

8 Commits

Author SHA1 Message Date
Chen Yongji
4308bc5dbb Fix docs typo in PipelineForImageToImage name (#4800) 2023-08-28 08:45:45 +02:00
Patrick von Platen
de9c72d58c Release: v0.19.3 2023-07-30 12:01:46 +02:00
Patrick von Platen
7b022df49c [SDXL] Fix dummy imports 2023-07-30 12:01:17 +02:00
Sayak Paul
965e52ce61 Release: v0.19.2 2023-07-28 23:33:51 +05:30
Sayak Paul
b1e52794a2 [Feat] Support SDXL Kohya-style LoRA (#4287)
* sdxl lora changes.

* better name replacement.

* better replacement.

* debugging

* debugging

* debugging

* debugging

* debugging

* remove print.

* print state dict keys.

* print

* distinguish better

* debuggable.

* fix: tests

* fix: arg from training script.

* access from class.

* run style

* debug

* save intermediate

* some simplifications for SDXL LoRA

* styling

* unet config is not needed in diffusers format.

* fix: dynamic SGM block mapping for SDXL kohya loras (#4322)

* Use lora compatible layers for linear proj_in/proj_out (#4323)

* improve condition for using the sgm_diffusers mapping

* informative comment.

* load compatible keys and embedding layer mapping.

* Get SDXL 1.0 example lora to load

* simplify

* specify ranks and hidden sizes.

* better handling of k rank and hidden

* debug

* debug

* debug

* debug

* debug

* fix: alpha keys

* add check for handling LoRAAttnAddedKVProcessor

* sanity comment

* modifications for text encoder SDXL

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* up

* up

* up

* up

* up

* up

* unneeded comments.

* unneeded comments.

* kwargs for the other attention processors.

* kwargs for the other attention processors.

* debugging

* debugging

* debugging

* debugging

* improve

* debugging

* debugging

* more print

* Fix alphas

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up

* clean up.

* debugging

* fix: text

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Batuhan Taskaya <batuhan@python.org>
2023-07-28 23:28:22 +05:30
Patrick von Platen
c3e3a1ee10 [SDXL] Make watermarker optional under certain circumstances to improve usability of SDXL 1.0 (#4346)
* improve sdxl

* more fixes

* improve sdxl

* improve sdxl

* improve sdxl

* finish
2023-07-28 23:28:11 +05:30
Patrick von Platen
9cde56a729 [ONNX] Don't download ONNX model by default (#4338)
* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* [Download] Don't download ONNX weights by default

* fix more

* finish

* finish

* finish
2023-07-28 23:27:58 +05:30
Patrick von Platen
c63d7cdba0 [SDXL Refiner] Fix refiner forward pass for batched input (#4327)
* fix_batch_xl

* Fix other pipelines as well

* up

* up

* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py

* sort

* up

* Finish it all up Co-authored-by: Bagheera <bghira@users.github.com>

* Co-authored-by: Bagheera bghira@users.github.com

* Co-authored-by: Bagheera <bghira@users.github.com>

* Finish it all up Co-authored-by: Bagheera <bghira@users.github.com>
2023-07-28 23:27:36 +05:30
36 changed files with 872 additions and 324 deletions

View File

@@ -21,7 +21,7 @@ For example, to perform Image-to-Image with the SD1.5 checkpoint, you can do
```python
from diffusers import PipelineForImageToImage
pipe_i2i = PipelineForImageoImage.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe_i2i = PipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
```
It will also help you switch between tasks seamlessly using the same checkpoint without reallocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do
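A purely hypothetical sketch of that reuse pattern (the `PipelineForInpainting` name and the constructor-from-components call are assumptions made here for illustration; neither appears in this diff):

```python
# Hypothetical: PipelineForInpainting is assumed by analogy with
# PipelineForImageToImage and is not part of this diff. Reusing `components`
# avoids re-allocating the already-loaded modules.
from diffusers import PipelineForImageToImage, PipelineForInpainting

pipe_i2i = PipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe_inpaint = PipelineForInpainting(**pipe_i2i.components)
```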

View File

@@ -38,9 +38,25 @@ You can install the libraries as follows:
pip install transformers
pip install accelerate
pip install safetensors
```
### Watermarker
We recommend adding an invisible watermark to images generated by Stable Diffusion XL; this can help identify whether an image is machine-synthesised for downstream applications. To do so, please install
the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:
```
pip install invisible-watermark>=0.2.0
```
If the `invisible-watermark` library is installed, the watermarker will be used **by default**.
If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:
```py
pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
```
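For instance, a minimal sketch of loading SDXL with the watermarker explicitly disabled (checkpoint id and dtype are illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Disable the invisible watermarker, e.g. when you have your own provenance
# tooling downstream; omit `add_watermarker` to keep the default behaviour.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    add_watermarker=False,
).to("cuda")

image = pipe("an astronaut riding a horse on the moon").images[0]
```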
### Text-to-Image
You can use SDXL as follows for *text-to-image*:

View File

@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
Here are some example checkpoints we tried out:
* SDXL 0.9:
* https://civitai.com/models/22279?modelVersionId=118556
* https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
* https://civitai.com/models/108448/daiton-sdxl-test
* https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
* https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
Here is an example of how to perform inference with these checkpoints in `diffusers`:
```python
from diffusers import DiffusionPipeline
import torch
base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7
image = pipeline(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```
`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556.
As you may have noticed, the inference UX is identical to what we presented in the sections above.
Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature.
### Known limitations specific to the Kohya-styled LoRAs
* SDXL LoRAs that include both text encoders currently lead to unexpected results. We're actively investigating the issue.
* When images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).

View File

@@ -4,6 +4,5 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb

View File

@@ -4,4 +4,3 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0

View File

@@ -924,10 +924,10 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook)

View File

@@ -829,13 +829,13 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
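Outside of these training hooks, the renamed keyword works the same way when loading a saved LoRA checkpoint manually; a minimal sketch, assuming `pipe` is an already-instantiated pipeline and the checkpoint directory is illustrative:

```python
from diffusers.loaders import LoraLoaderMixin

# `lora_state_dict` now returns a dict of per-module alphas (`network_alphas`)
# instead of a single `network_alpha` float.
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict("output/checkpoint-500")

LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=pipe.unet)
LoraLoaderMixin.load_lora_into_text_encoder(
    lora_state_dict, network_alphas=network_alphas, text_encoder=pipe.text_encoder
)
```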

View File

@@ -233,7 +233,7 @@ install_requires = [
setup(
name="diffusers",
version="0.19.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.19.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.19.1"
__version__ = "0.19.3"
from .configuration_utils import ConfigMixin
from .utils import (
@@ -185,6 +185,11 @@ else:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
TextToVideoSDPipeline,
@@ -202,20 +207,6 @@ else:
VQDiffusionPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .pipelines import (
StableDiffusionXLControlNetPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()

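The net effect of moving the SDXL pipelines out of the `invisible_watermark` guard is that they import whenever torch and transformers are available; a quick sanity check (no output expected if the import succeeds):

```python
# With the reshuffled guards above, this no longer requires the
# `invisible-watermark` package; the watermarker is simply skipped at runtime
# when the library is missing.
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
```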
View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
@@ -56,7 +57,6 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = kwargs.pop("use_safetensors", None)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alpha = kwargs.pop("network_alpha", None)
network_alphas = kwargs.pop("network_alphas", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin:
attn_processors = {}
non_attn_lora_layers = []
is_lora = all("lora" in k for k in state_dict.keys())
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
if is_lora:
@@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin:
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
mapped_network_alphas = {}
all_keys = list(state_dict.keys())
for key in all_keys:
value = state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
if len(state_dict) > 0:
raise ValueError(
f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
@@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin:
# or add_{k,v,q,out_proj}_proj_lora layers.
if "lora.down.weight" in value_dict:
rank = value_dict["lora.down.weight"].shape[0]
hidden_size = value_dict["lora.up.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features, attn_processor.out_features, rank, network_alpha
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
@@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin:
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
non_attn_lora_layers.append((attn_processor, lora))
continue
rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
# To handle SDXL.
rank_mapping = {}
hidden_size_mapping = {}
for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
else:
attn_processor_class = (
LoRAAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAAttnProcessor
)
if attn_processor_class is not LoRAAttnAddedKVProcessor:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight"),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
q_rank=rank_mapping.get("to_q_lora.down.weight"),
q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
v_rank=rank_mapping.get("to_v_lora.down.weight"),
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
out_rank=rank_mapping.get("to_out_lora.down.weight"),
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
# rank=rank_mapping.get("to_k_lora.down.weight", None),
# hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
# q_rank=rank_mapping.get("to_q_lora.down.weight", None),
# q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
# v_rank=rank_mapping.get("to_v_lora.down.weight", None),
# v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
# out_rank=rank_mapping.get("to_out_lora.down.weight", None),
# out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
)
else:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight", None),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
)
attn_processors[key] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
@@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin:
# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
if hasattr(target_module, "set_lora_layer"):
target_module.set_lora_layer(lora_layer)
target_module.set_lora_layer(lora_layer)
# It should raise an error if we don't have a set lora here
# if hasattr(target_module, "set_lora_layer"):
# target_module.set_lora_layer(lora_layer)
def save_attn_procs(
self,
@@ -880,11 +943,11 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_text_encoder(
state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
)
@@ -896,7 +959,7 @@ class LoraLoaderMixin:
**kwargs,
):
r"""
Return state dict for lora weights
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
@@ -957,6 +1020,7 @@ class LoraLoaderMixin:
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
@@ -1018,53 +1082,158 @@ class LoraLoaderMixin:
else:
state_dict = pretrained_model_name_or_path_or_dict
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
network_alphas = None
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alpha
return state_dict, network_alphas
@classmethod
def load_lora_into_unet(cls, state_dict, network_alpha, unet):
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
new_state_dict = {}
inner_block_map = ["resnets", "attentions", "upsamplers"]
# Retrieves # of down, mid and up blocks
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
for layer in state_dict:
if "text" not in layer:
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
if "input_blocks" in layer:
input_block_ids.add(layer_id)
elif "middle_block" in layer:
middle_block_ids.add(layer_id)
elif "output_blocks" in layer:
output_block_ids.add(layer_id)
else:
raise ValueError("Checkpoint not supported")
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
for layer_id in input_block_ids
}
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
for layer_id in middle_block_ids
}
output_blocks = {
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
for layer_id in output_block_ids
}
# Rename keys accordingly
for i in input_block_ids:
block_id = (i - 1) // (unet_config.layers_per_block + 1)
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
for key in input_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in middle_block_ids:
key_part = None
if i == 0:
key_part = [inner_block_map[0], "0"]
elif i == 1:
key_part = [inner_block_map[1], "0"]
elif i == 2:
key_part = [inner_block_map[0], "1"]
else:
raise ValueError(f"Invalid middle block id {i}.")
for key in middle_blocks[i]:
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in output_block_ids:
block_id = i // (unet_config.layers_per_block + 1)
layer_in_block_id = i % (unet_config.layers_per_block + 1)
for key in output_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id]
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
if is_all_unet and len(state_dict) > 0:
raise ValueError("At this point all state dict entries have to be converted.")
else:
# Remaining is the text encoder state dict.
for k, v in state_dict.items():
new_state_dict.update({k: v})
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet):
"""
This will load the LoRA layers specified in `state_dict` into `unet`
This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alpha (`float`):
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
"""
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
logger.info(f"Loading {cls.unet_name}.")
unet_lora_state_dict = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
elif not all(
key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
):
unet.load_attn_procs(state_dict, network_alpha=network_alpha)
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
network_alphas = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
# load loras into unet
unet.load_attn_procs(state_dict, network_alphas=network_alphas)
@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1072,7 +1241,7 @@ class LoraLoaderMixin:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alpha (`float`):
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
@@ -1141,14 +1310,19 @@ class LoraLoaderMixin:
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
)
# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
if len(load_state_dict_results.unexpected_keys) != 0:
raise ValueError(
@@ -1183,7 +1357,7 @@ class LoraLoaderMixin:
cls,
text_encoder,
lora_scale=1,
network_alpha=None,
network_alphas=None,
rank=4,
dtype=None,
patch_mlp=False,
@@ -1196,37 +1370,46 @@ class LoraLoaderMixin:
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
lora_parameters = []
network_alphas = {} if network_alphas is None else network_alphas
for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.get(name + ".k.proj.alpha")
key_alpha = network_alphas.get(name + ".q.proj.alpha")
value_alpha = network_alphas.get(name + ".v.proj.alpha")
proj_alpha = network_alphas.get(name + ".out.proj.alpha")
for _, attn_module in text_encoder_attn_modules(text_encoder):
attn_module.q_proj = PatchedLoraProjection(
attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
attn_module.k_proj = PatchedLoraProjection(
attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
attn_module.v_proj = PatchedLoraProjection(
attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
attn_module.out_proj = PatchedLoraProjection(
attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
if patch_mlp:
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
fc1_alpha = network_alphas.get(name + ".fc1.alpha")
fc2_alpha = network_alphas.get(name + ".fc2.alpha")
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
mlp_module.fc2 = PatchedLoraProjection(
mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
@@ -1333,77 +1516,163 @@ class LoraLoaderMixin:
def _convert_kohya_lora_to_diffusers(cls, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
unloaded_keys = []
te2_state_dict = {}
network_alphas = {}
for key, value in state_dict.items():
if "hada" in key or "skip" in key:
unloaded_keys.append(key)
elif "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name_alpha in state_dict:
alpha = state_dict[lora_name_alpha].item()
if network_alpha is None:
network_alpha = alpha
elif network_alpha != alpha:
raise ValueError("Network alpha is not consistent")
# every down weight has a corresponding up weight and potentially an alpha weight
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
# if lora_name_alpha in state_dict:
# alpha = state_dict.pop(lora_name_alpha).item()
# network_alphas.update({lora_name_alpha: alpha})
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
if "input.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
logger.info("Kohya-style checkpoint detected.")
if len(unloaded_keys) > 0:
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
logger.warning(
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
# SDXL specificity.
if "emb" in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te1_"):
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te2_"):
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Rename the alphas so that they can be mapped appropriately.
if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha})
if len(state_dict) > 0:
raise ValueError(
f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}"
)
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {
f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
if len(te2_state_dict) > 0
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alpha
return new_state_dict, network_alphas
def unload_lora_weights(self):
"""

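To inspect what the converted state dict looks like for a Kohya-style SDXL file, the new `unet_config` argument introduced above can be passed to `lora_state_dict` so that the SGM block names are remapped first; a minimal sketch (the repo id and weight name come from the docs section earlier in this diff):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Passing the UNet config triggers the SGM -> diffusers block remapping
# (`_map_sgm_blocks_to_diffusers`) before the kohya -> diffusers key conversion.
state_dict, network_alphas = pipe.lora_state_dict(
    "stabilityai/stable-diffusion-xl-base-1.0",
    weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    unet_config=pipe.unet.config,
)
print(len(state_dict), len(network_alphas) if network_alphas else 0)
```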
View File

@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
"""
def __init__(
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
self,
hidden_size,
cross_attention_dim,
rank=4,
attention_op: Optional[Callable] = None,
network_alpha=None,
**kwargs,
):
super().__init__()
@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.rank = rank
self.attention_op = attention_op
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
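A minimal sketch of constructing one of these processors with the new per-projection keyword arguments (the sizes below are illustrative, not taken from a specific UNet):

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

# q/v/out projections may override `rank`/`hidden_size` via q_rank, v_rank,
# out_rank (and the matching *_hidden_size kwargs); to_k keeps `rank` and uses
# `cross_attention_dim` as its input dimension.
processor = LoRAAttnProcessor(
    hidden_size=1280,
    cross_attention_dim=2048,
    rank=4,
    network_alpha=4.0,
    q_rank=8,
    q_hidden_size=1280,
    out_rank=8,
    out_hidden_size=1280,
)
```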

View File

@@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module):
class LoRAConv2dLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
# see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
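For example, a down projection that mirrors a 3x3 convolution now looks roughly like this (shapes are illustrative; the module path assumes `src/diffusers/models/lora.py`):

```python
from diffusers.models.lora import LoRAConv2dLayer

# The down conv inherits the target conv's kernel/stride/padding, while the
# up conv stays 1x1 as in the kohya_ss trainer.
lora_layer = LoRAConv2dLayer(
    in_features=320, out_features=320, rank=4, kernel_size=(3, 3), stride=(1, 1), padding=1
)
print(lora_layer.down.weight.shape)  # torch.Size([4, 320, 3, 3])
print(lora_layer.up.weight.shape)    # torch.Size([320, 4, 1, 1])
```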

View File

@@ -23,6 +23,7 @@ import torch.nn.functional as F
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear
class Upsample1D(nn.Module):
@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
self.name = name
if use_conv:
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
self.conv_shortcut = LoRACompatibleConv(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)

View File

@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:

View File

@@ -1,7 +1,6 @@
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_invisible_watermark_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
@@ -51,6 +50,7 @@ else:
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
@@ -108,6 +108,12 @@ else:
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .t2i_adapter import StableDiffusionAdapterPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
@@ -121,20 +127,6 @@ else:
from .vq_diffusion import VQDiffusionPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .controlnet import StableDiffusionXLControlNetPipeline
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()

View File

@@ -1,21 +1,11 @@
from ...utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_invisible_watermark_available,
is_torch_available,
is_transformers_available,
)
try:
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else:
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -26,6 +16,7 @@ else:
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
if is_transformers_available() and is_flax_available():

View File

@@ -22,6 +22,8 @@ import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
@@ -42,7 +44,11 @@ from ...utils import (
)
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
from .multicontrolnet import MultiControlNetModel
@@ -109,6 +115,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
controlnet: ControlNetModel,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()
@@ -130,7 +137,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -995,7 +1008,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
image = latents
return StableDiffusionXLPipelineOutput(images=image)
image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU

View File

@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
_is_onnx = False
def register_modules(self, **kwargs):
# import it here to avoid circular import
@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. If `None`, `use_onnx` defaults to the `_is_onnx` class attribute, which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. If `None`, `use_onnx` defaults to the `_is_onnx` class attribute, which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
Returns:
`os.PathLike`:
@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if use_safetensors and not is_safetensors_available():
@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames}
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
):
ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
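With the ignore patterns above, `*.onnx` and `*.pb` files are skipped unless `use_onnx` resolves to `True`. A minimal sketch of explicitly requesting ONNX weights via `DiffusionPipeline.download` (using the tiny test repository referenced in the tests further below):

from diffusers import DiffusionPipeline

# Sketch: request ONNX weights explicitly; without use_onnx=True a non-ONNX
# pipeline class would ignore *.onnx and *.pb files during download.
local_dir = DiffusionPipeline.download(
    "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
    use_onnx=True,
)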

View File

@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
self,

View File

@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel

View File

@@ -46,6 +46,8 @@ def preprocess(image):
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
_is_onnx = True
def __init__(
self,
vae: OnnxRuntimeModel,

View File

@@ -7,7 +7,6 @@ import PIL
from ...utils import (
BaseOutput,
OptionalDependencyNotAvailable,
is_invisible_watermark_available,
is_torch_available,
is_transformers_available,
)
@@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
try:
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline

View File

@@ -32,13 +32,17 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -84,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
- *LoRA*: [`StableDiffusionXLPipeline.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -125,6 +129,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()
@@ -142,7 +147,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -839,7 +849,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
image = latents
return StableDiffusionXLPipelineOutput(images=image)
image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
@@ -853,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# Override to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
@@ -870,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
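The override above fetches `network_alphas` alongside the state dict and passes the UNet config so Kohya-style SDXL key layouts can be remapped before loading into the UNet and both text encoders. A minimal usage sketch (repository and file names are illustrative):

from diffusers import StableDiffusionXLPipeline

# Sketch: load a Kohya-style SDXL LoRA into the UNet and both text encoders.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative checkpoint
)
pipe.load_lora_weights(
    "some-user/sdxl-lora",  # illustrative LoRA repository
    weight_name="pytorch_lora_weights.safetensors",
)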

View File

@@ -33,13 +33,17 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -131,6 +135,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()
@@ -148,7 +153,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -906,15 +916,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -988,7 +1000,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
image = latents
return StableDiffusionXLPipelineOutput(images=image)
image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU

View File

@@ -30,10 +30,20 @@ from ...models.attention_processor import (
XFormersAttnProcessor,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -265,6 +275,7 @@ class StableDiffusionXLInpaintPipeline(
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()
@@ -282,7 +293,12 @@ class StableDiffusionXLInpaintPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -1168,15 +1184,17 @@ class StableDiffusionXLInpaintPipeline(
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)
# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1264,6 +1282,10 @@ class StableDiffusionXLInpaintPipeline(
else:
return StableDiffusionXLPipelineOutput(images=latents)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU

View File

@@ -34,12 +34,16 @@ from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
randn_tensor,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput
from .watermark import StableDiffusionXLWatermarker
if is_invisible_watermark_available():
from .watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -109,6 +113,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__()
@@ -128,7 +133,12 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
self.vae.config.force_upcast = True # force the VAE to be in float32 mode, as it overflows in float16
self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
if add_watermarker:
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None
def enable_vae_slicing(self):
r"""
@@ -811,6 +821,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
original_prompt_embeds_len = len(prompt_embeds)
original_add_text_embeds_len = len(add_text_embeds)
@@ -819,6 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
if do_classifier_free_guidance:
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
# Make dimensions consistent
@@ -828,7 +840,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
prompt_embeds = prompt_embeds.to(device).to(torch.float32)
add_text_embeds = add_text_embeds.to(device).to(torch.float32)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)
# 11. Denoising loop
self.unet = self.unet.to(torch.float32)
@@ -906,7 +918,10 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
image = latents
return StableDiffusionXLPipelineOutput(images=image)
image = self.watermark.apply_watermark(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU

View File

@@ -1,6 +1,11 @@
import numpy as np
import torch
from imwatermark import WatermarkEncoder
from ...utils import is_invisible_watermark_available
if is_invisible_watermark_available():
from imwatermark import WatermarkEncoder
# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66

View File

@@ -1,77 +0,0 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
class StableDiffusionXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])

View File

@@ -827,6 +827,81 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
self.assertTrue(np.allclose(images, expected, atol=1e-4))

View File

@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32,
)
scheduler = EulerDiscreteScheduler(
@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
}
return components
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32,
)
scheduler = EulerDiscreteScheduler(
@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
}
return components
@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
}
return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_stable_diffusion_xl_inpaint_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
print(torch.from_numpy(image_slice).flatten())
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64,
)
@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
# "safety_checker": None,
# "feature_extractor": None,
"requires_aesthetics_score": True,
}
return components
@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
}
return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

View File

@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files)
def test_download_no_openvino_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-open-vino",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no openvino weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any("openvino_" in f for f in files)
def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# make sure that by default no onnx weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]
# if `use_onnx` is specified make sure weights are downloaded
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)
def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(