mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00

Compare commits: inpainting...v0.19.3 (14 commits)
| SHA1 |
|---|
| de9c72d58c |
| 7b022df49c |
| 965e52ce61 |
| b1e52794a2 |
| c3e3a1ee10 |
| 9cde56a729 |
| c63d7cdba0 |
| aa4634a7fa |
| 0709650e9d |
| a9829164f4 |
| 49c95178ad |
| c2f755bc62 |
| 2fb877b66c |
| ef9824f9f7 |
@@ -39,8 +39,8 @@ Currently AutoPipeline support the Text-to-Image, Image-to-Image, and Inpainting
- [Stable Diffusion Controlnet](./api/pipelines/controlnet)
- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl)
- [IF](./if)
- [Kandinsky](./kandinsky)
- [Kandinsky 2.2]()
- [Kandinsky](./kandinsky)
- [Kandinsky 2.2](./kandinsky)

## AutoPipelineForText2Image

@@ -105,6 +105,30 @@ One cheeseburger monster coming up! Enjoy!



<Tip>

We also provide an end-to-end Kandinsky pipeline [`KandinskyCombinedPipeline`], which combines both the prior pipeline and the text-to-image pipeline, and lets you perform inference in a single step. You can create the combined pipeline with the [`~AutoPipelineForText2Image.from_pretrained`] method:

```python
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
```

Under the hood, it will automatically load both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. To generate images, you no longer need to call both pipelines and pass the outputs from one to another; you only need to call the combined pipeline once. You can set a different `guidance_scale` and `num_inference_steps` for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` arguments.

```python
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"

image = pipe(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]
```
</Tip>

The Kandinsky model works extremely well with creative prompts. Here is some of the amazing art that can be created using the exact same process but with different prompts.

```python
@@ -187,6 +211,34 @@ out.images[0].save("fantasy_land.png")



<Tip>

You can also use the [`KandinskyImg2ImgCombinedPipeline`] for end-to-end image-to-image generation with Kandinsky 2.1:

```python
from diffusers import AutoPipelineForImage2Image
import torch
import requests
from io import BytesIO
from PIL import Image
import os

pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")
original_image.thumbnail((768, 768))

image = pipe(prompt=prompt, image=original_image, strength=0.3).images[0]
```
</Tip>

### Text Guided Inpainting Generation

You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat.
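The full snippet falls outside the hunk shown here, but as a rough sketch (not the verbatim docs example), the two-stage flow could look like the code below. The prior checkpoint name, the cat image URL, and the mask region are illustrative assumptions; note that, per the mask-format change described further down, white pixels mark the area to repaint.

```python
from diffusers import KandinskyPriorPipeline, KandinskyInpaintPipeline
from diffusers.utils import load_image
import numpy as np
import torch

# The prior turns the text prompt into the image embeddings that condition the inpainting decoder.
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
).to("cuda")
pipe = KandinskyInpaintPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
).to("cuda")

prompt = "a hat"
image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

# Hypothetical input image of a cat.
init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)

# White (1.0) pixels mark the region to repaint; here, a band above the cat's head.
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1

image = pipe(
    prompt,
    image=init_image,
    mask_image=mask,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
).images[0]
image.save("cat_with_hat.png")
```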
@@ -231,6 +283,33 @@ image.save("cat_with_hat.png")
```



<Tip>

To use the [`KandinskyInpaintCombinedPipeline`] for end-to-end image inpainting generation, you can run the code below instead:

```python
from diffusers import AutoPipelineForInpainting
import torch

pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
image = pipe(prompt=prompt, image=original_image, mask_image=mask).images[0]
```
</Tip>

🚨🚨🚨 __Breaking change for Kandinsky Mask Inpainting__ 🚨🚨🚨

We introduced a breaking change for the Kandinsky inpainting pipeline in the following pull request: https://github.com/huggingface/diffusers/pull/4207. Previously we accepted a mask format where black pixels represented the masked-out area. This was inconsistent with all other pipelines in diffusers; we have changed the mask format in Kandinsky and now use white pixels instead.
Please upgrade your inpainting code to follow the change above. If you are using the Kandinsky inpainting pipeline in production, you now need to change the mask as follows:

```python
# For PIL input
import PIL.ImageOps
mask = PIL.ImageOps.invert(mask)

# For PyTorch and NumPy input
mask = 1 - mask
```

### Interpolate

The [`KandinskyPriorPipeline`] also comes with a cool utility function that will allow you to interpolate the latent space of different images and texts super easily. Here is an example of how you can create an Impressionist-style portrait for your pet based on "The Starry Night".
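The interpolation example itself is not part of this hunk; the following is a minimal sketch of how the `interpolate` utility can be driven, assuming the Kandinsky 2.1 prior/decoder checkpoints and two illustrative image URLs (a pet photo and "The Starry Night"):

```python
from diffusers import KandinskyPriorPipeline, KandinskyPipeline
from diffusers.utils import load_image
import torch

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
).to("cuda")
pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16).to("cuda")

# Assumed example inputs: a pet photo and "The Starry Night".
img1 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
img2 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg"
)

# Mix a text prompt and the two images with relative weights.
images_texts = ["a cat", img1, img2]
weights = [0.3, 0.3, 0.4]
image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights).to_tuple()

# The interpolated embeddings carry the content, so the decoder prompt can stay empty.
image = pipe(
    prompt="",
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
).images[0]
image.save("starry_cat.png")
```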
@@ -11,7 +11,22 @@ specific language governing permissions and limitations under the License.


The Kandinsky 2.2 release includes robust new text-to-image models that support text-to-image generation, image-to-image generation, image interpolation, and text-guided image inpainting. The general workflow to perform these tasks using Kandinsky 2.2 is the same as in Kandinsky 2.1. First, you will need to use a prior pipeline to generate image embeddings based on your text prompt, and then use one of the image decoding pipelines to generate the output image. The only difference is that in Kandinsky 2.2, all of the decoding pipelines no longer accept the `prompt` input, and the image generation process is conditioned with only `image_embeds` and `negative_image_embeds`.

Let's look at an example of how to perform text-to-image generation using Kandinsky 2.2.
As with Kandinsky 2.1, the easiest way to perform text-to-image generation is to use the combined Kandinsky pipeline. The process is exactly the same as for Kandinsky 2.1; all you need to do is replace the Kandinsky 2.1 checkpoint with the 2.2 one.

```python
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"

image = pipe(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, height=768, width=768).images[0]
```

Now, let's look at an example where we take separate steps to run the prior pipeline and text-to-image pipeline. This way, we can understand what's happening under the hood and how Kandinsky 2.2 differs from Kandinsky 2.1.

First, let's create the prior pipeline and text-to-image pipeline with Kandinsky 2.2 checkpoints.
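The rest of the walkthrough is outside the hunk shown here; a minimal sketch of what the two-step setup and call could look like (assuming the `kandinsky-2-2-prior` and `kandinsky-2-2-decoder` checkpoints) is:

```python
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
import torch

# Prior: maps the text prompt to image embeddings.
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior.enable_model_cpu_offload()

# Decoder: generates the image from the embeddings only; it takes no `prompt` argument.
pipe = KandinskyV22Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"

image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, guidance_scale=1.0).to_tuple()

image = pipe(
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    height=768,
    width=768,
).images[0]
image.save("cheeseburger_monster.png")
```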
@@ -38,9 +38,25 @@ You can install the libraries as follows:
pip install transformers
pip install accelerate
pip install safetensors
```

### Watermarker

We recommend adding an invisible watermark to images generated by Stable Diffusion XL; this can help identify whether an image is machine-synthesised for downstream applications. To do so, please install
the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:

```
pip install "invisible-watermark>=0.2.0"
```

If the `invisible-watermark` library is installed, the watermarker will be used **by default**.

If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:

```py
pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
```

### Text-to-Image

You can use SDXL as follows for *text-to-image*:
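The actual snippet is outside the lines changed in this hunk; a minimal sketch (the prompt is an arbitrary example, and the `stable-diffusion-xl-base-1.0` checkpoint is assumed) could be:

```python
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt=prompt).images[0]
image.save("astronaut.png")
```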
@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```

### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer

With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).

Here are some example checkpoints we tried out:

* SDXL 0.9:
  * https://civitai.com/models/22279?modelVersionId=118556
  * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
  * https://civitai.com/models/108448/daiton-sdxl-test
  * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
  * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors

Here is an example of how to perform inference with these checkpoints in `diffusers`:

```python
from diffusers import DiffusionPipeline
import torch

base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")

prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7

image = pipeline(
    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
    generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```

`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556.

If you look carefully, the inference UX is identical to what we presented in the sections above.

Thanks to [@isidentical](https://github.com/isidentical) for helping us integrate this feature.

### Known limitations specific to the Kohya-styled LoRAs

* SDXL LoRAs that include both text encoders currently lead to weird results. We're actively investigating the issue.
* When images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
@@ -4,6 +4,5 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb

@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -4,4 +4,3 @@ transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
invisible-watermark>=0.2.0
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -69,7 +69,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -924,10 +924,10 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -829,13 +829,13 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.18.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.19.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
setup.py
@@ -233,7 +233,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.19.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.19.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.19.0.dev0"
|
||||
__version__ = "0.19.3"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -185,6 +185,11 @@ else:
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
TextToVideoSDPipeline,
|
||||
@@ -202,20 +207,6 @@ else:
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
@@ -56,7 +57,6 @@ UNET_NAME = "unet"
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
TOTAL_EXAMPLE_KEYS = 5
|
||||
|
||||
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
||||
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
@@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin:
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
network_alpha = kwargs.pop("network_alpha", None)
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ValueError(
|
||||
@@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processors = {}
|
||||
non_attn_lora_layers = []
|
||||
|
||||
is_lora = all("lora" in k for k in state_dict.keys())
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
@@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin:
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
mapped_network_alphas = {}
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
for key in all_keys:
|
||||
value = state_dict.pop(key)
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
lora_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
|
||||
if network_alphas is not None:
|
||||
for k in network_alphas:
|
||||
if k.replace(".alpha", "") in key:
|
||||
mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(
|
||||
f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key, value_dict in lora_grouped_dict.items():
|
||||
attn_processor = self
|
||||
for sub_key in key.split("."):
|
||||
@@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin:
|
||||
# or add_{k,v,q,out_proj}_proj_lora layers.
|
||||
if "lora.down.weight" in value_dict:
|
||||
rank = value_dict["lora.down.weight"].shape[0]
|
||||
hidden_size = value_dict["lora.up.weight"].shape[0]
|
||||
|
||||
if isinstance(attn_processor, LoRACompatibleConv):
|
||||
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
in_features = attn_processor.in_channels
|
||||
out_features = attn_processor.out_channels
|
||||
kernel_size = attn_processor.kernel_size
|
||||
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
elif isinstance(attn_processor, LoRACompatibleLinear):
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features, attn_processor.out_features, rank, network_alpha
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
|
||||
@@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin:
|
||||
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
|
||||
lora.load_state_dict(value_dict)
|
||||
non_attn_lora_layers.append((attn_processor, lora))
|
||||
continue
|
||||
|
||||
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
||||
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
||||
|
||||
if isinstance(
|
||||
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
||||
):
|
||||
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
||||
attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
||||
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
||||
attn_processor_class = LoRAXFormersAttnProcessor
|
||||
# To handle SDXL.
|
||||
rank_mapping = {}
|
||||
hidden_size_mapping = {}
|
||||
for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
|
||||
rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
|
||||
hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
|
||||
|
||||
rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
|
||||
hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
|
||||
|
||||
if isinstance(
|
||||
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
||||
):
|
||||
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
||||
attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
||||
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
||||
attn_processor_class = LoRAXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
LoRAAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else LoRAAttnProcessor
|
||||
)
|
||||
|
||||
if attn_processor_class is not LoRAAttnAddedKVProcessor:
|
||||
attn_processors[key] = attn_processor_class(
|
||||
rank=rank_mapping.get("to_k_lora.down.weight"),
|
||||
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
q_rank=rank_mapping.get("to_q_lora.down.weight"),
|
||||
q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
|
||||
v_rank=rank_mapping.get("to_v_lora.down.weight"),
|
||||
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
|
||||
out_rank=rank_mapping.get("to_out_lora.down.weight"),
|
||||
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
|
||||
# rank=rank_mapping.get("to_k_lora.down.weight", None),
|
||||
# hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
|
||||
# q_rank=rank_mapping.get("to_q_lora.down.weight", None),
|
||||
# q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
|
||||
# v_rank=rank_mapping.get("to_v_lora.down.weight", None),
|
||||
# v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
|
||||
# out_rank=rank_mapping.get("to_out_lora.down.weight", None),
|
||||
# out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
|
||||
)
|
||||
else:
|
||||
attn_processors[key] = attn_processor_class(
|
||||
rank=rank_mapping.get("to_k_lora.down.weight", None),
|
||||
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
|
||||
attn_processors[key] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
rank=rank,
|
||||
network_alpha=network_alpha,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
|
||||
elif is_custom_diffusion:
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
@@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# set ff layers
|
||||
for target_module, lora_layer in non_attn_lora_layers:
|
||||
if hasattr(target_module, "set_lora_layer"):
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
# It should raise an error if we don't have a set lora here
|
||||
# if hasattr(target_module, "set_lora_layer"):
|
||||
# target_module.set_lora_layer(lora_layer)
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
@@ -880,11 +943,11 @@ class LoraLoaderMixin:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alpha=network_alpha,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
)
|
||||
@@ -896,7 +959,7 @@ class LoraLoaderMixin:
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@@ -957,6 +1020,7 @@ class LoraLoaderMixin:
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
unet_config = kwargs.pop("unet_config", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
@@ -1018,53 +1082,158 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
|
||||
network_alpha = None
|
||||
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
|
||||
state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
|
||||
network_alphas = None
|
||||
if all(
|
||||
(
|
||||
k.startswith("lora_te_")
|
||||
or k.startswith("lora_unet_")
|
||||
or k.startswith("lora_te1_")
|
||||
or k.startswith("lora_te2_")
|
||||
)
|
||||
for k in state_dict.keys()
|
||||
):
|
||||
# Map SDXL blocks correctly.
|
||||
if unet_config is not None:
|
||||
# use unet config to remap block numbers
|
||||
state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict, network_alpha
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(cls, state_dict, network_alpha, unet):
|
||||
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
||||
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
|
||||
new_state_dict = {}
|
||||
inner_block_map = ["resnets", "attentions", "upsamplers"]
|
||||
|
||||
# Retrieves # of down, mid and up blocks
|
||||
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
|
||||
for layer in state_dict:
|
||||
if "text" not in layer:
|
||||
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
|
||||
if "input_blocks" in layer:
|
||||
input_block_ids.add(layer_id)
|
||||
elif "middle_block" in layer:
|
||||
middle_block_ids.add(layer_id)
|
||||
elif "output_blocks" in layer:
|
||||
output_block_ids.add(layer_id)
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported")
|
||||
|
||||
input_blocks = {
|
||||
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
|
||||
for layer_id in input_block_ids
|
||||
}
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
|
||||
for layer_id in middle_block_ids
|
||||
}
|
||||
output_blocks = {
|
||||
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
|
||||
for layer_id in output_block_ids
|
||||
}
|
||||
|
||||
# Rename keys accordingly
|
||||
for i in input_block_ids:
|
||||
block_id = (i - 1) // (unet_config.layers_per_block + 1)
|
||||
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
|
||||
|
||||
for key in input_blocks[i]:
|
||||
inner_block_id = int(key.split(delimiter)[block_slice_pos])
|
||||
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
|
||||
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
|
||||
new_key = delimiter.join(
|
||||
key.split(delimiter)[: block_slice_pos - 1]
|
||||
+ [str(block_id), inner_block_key, inner_layers_in_block]
|
||||
+ key.split(delimiter)[block_slice_pos + 1 :]
|
||||
)
|
||||
new_state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
for i in middle_block_ids:
|
||||
key_part = None
|
||||
if i == 0:
|
||||
key_part = [inner_block_map[0], "0"]
|
||||
elif i == 1:
|
||||
key_part = [inner_block_map[1], "0"]
|
||||
elif i == 2:
|
||||
key_part = [inner_block_map[0], "1"]
|
||||
else:
|
||||
raise ValueError(f"Invalid middle block id {i}.")
|
||||
|
||||
for key in middle_blocks[i]:
|
||||
new_key = delimiter.join(
|
||||
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
|
||||
)
|
||||
new_state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
for i in output_block_ids:
|
||||
block_id = i // (unet_config.layers_per_block + 1)
|
||||
layer_in_block_id = i % (unet_config.layers_per_block + 1)
|
||||
|
||||
for key in output_blocks[i]:
|
||||
inner_block_id = int(key.split(delimiter)[block_slice_pos])
|
||||
inner_block_key = inner_block_map[inner_block_id]
|
||||
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
|
||||
new_key = delimiter.join(
|
||||
key.split(delimiter)[: block_slice_pos - 1]
|
||||
+ [str(block_id), inner_block_key, inner_layers_in_block]
|
||||
+ key.split(delimiter)[block_slice_pos + 1 :]
|
||||
)
|
||||
new_state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
if is_all_unet and len(state_dict) > 0:
|
||||
raise ValueError("At this point all state dict entries have to be converted.")
|
||||
else:
|
||||
# Remaining is the text encoder state dict.
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict.update({k: v})
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alpha (`float`):
|
||||
network_alphas (`Dict[str, float]`):
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
"""
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet_lora_state_dict = {
|
||||
k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
||||
}
|
||||
unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
|
||||
|
||||
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
elif not all(
|
||||
key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
|
||||
):
|
||||
unet.load_attn_procs(state_dict, network_alpha=network_alpha)
|
||||
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
|
||||
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
|
||||
network_alphas = {
|
||||
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
else:
|
||||
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
|
||||
warnings.warn(warn_message)
|
||||
|
||||
# load loras into unet
|
||||
unet.load_attn_procs(state_dict, network_alphas=network_alphas)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
|
||||
def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
|
||||
@@ -1072,7 +1241,7 @@ class LoraLoaderMixin:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alpha (`float`):
|
||||
network_alphas (`Dict[str, float]`):
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
@@ -1141,14 +1310,19 @@ class LoraLoaderMixin:
|
||||
].shape[1]
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
|
||||
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
text_encoder_lora_state_dict = {
|
||||
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
for k, v in text_encoder_lora_state_dict.items()
|
||||
}
|
||||
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
if len(load_state_dict_results.unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
@@ -1183,7 +1357,7 @@ class LoraLoaderMixin:
|
||||
cls,
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alpha=None,
|
||||
network_alphas=None,
|
||||
rank=4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
@@ -1196,37 +1370,46 @@ class LoraLoaderMixin:
|
||||
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
|
||||
|
||||
lora_parameters = []
|
||||
network_alphas = {} if network_alphas is None else network_alphas
|
||||
|
||||
for name, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
query_alpha = network_alphas.get(name + ".k.proj.alpha")
|
||||
key_alpha = network_alphas.get(name + ".q.proj.alpha")
|
||||
value_alpha = network_alphas.get(name + ".v.proj.alpha")
|
||||
proj_alpha = network_alphas.get(name + ".out.proj.alpha")
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
attn_module.q_proj = PatchedLoraProjection(
|
||||
attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.k_proj = PatchedLoraProjection(
|
||||
attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.v_proj = PatchedLoraProjection(
|
||||
attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.out_proj = PatchedLoraProjection(
|
||||
attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
|
||||
|
||||
if patch_mlp:
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.get(name + ".fc1.alpha")
|
||||
fc2_alpha = network_alphas.get(name + ".fc2.alpha")
|
||||
|
||||
mlp_module.fc1 = PatchedLoraProjection(
|
||||
mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
|
||||
|
||||
mlp_module.fc2 = PatchedLoraProjection(
|
||||
mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
|
||||
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
|
||||
|
||||
@@ -1333,77 +1516,163 @@ class LoraLoaderMixin:
|
||||
def _convert_kohya_lora_to_diffusers(cls, state_dict):
|
||||
unet_state_dict = {}
|
||||
te_state_dict = {}
|
||||
network_alpha = None
|
||||
unloaded_keys = []
|
||||
te2_state_dict = {}
|
||||
network_alphas = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if "hada" in key or "skip" in key:
|
||||
unloaded_keys.append(key)
|
||||
elif "lora_down" in key:
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = lora_name + ".lora_up.weight"
|
||||
lora_name_alpha = lora_name + ".alpha"
|
||||
if lora_name_alpha in state_dict:
|
||||
alpha = state_dict[lora_name_alpha].item()
|
||||
if network_alpha is None:
|
||||
network_alpha = alpha
|
||||
elif network_alpha != alpha:
|
||||
raise ValueError("Network alpha is not consistent")
|
||||
# every down weight has a corresponding up weight and potentially an alpha weight
|
||||
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
|
||||
for key in lora_keys:
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = lora_name + ".lora_up.weight"
|
||||
lora_name_alpha = lora_name + ".alpha"
|
||||
|
||||
if lora_name.startswith("lora_unet_"):
|
||||
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
||||
# if lora_name_alpha in state_dict:
|
||||
# alpha = state_dict.pop(lora_name_alpha).item()
|
||||
# network_alphas.update({lora_name_alpha: alpha})
|
||||
|
||||
if lora_name.startswith("lora_unet_"):
|
||||
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
||||
|
||||
if "input.blocks" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
||||
|
||||
if "middle.block" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
||||
if "output.blocks" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
|
||||
else:
|
||||
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
||||
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
||||
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
||||
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
||||
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
||||
if "transformer_blocks" in diffusers_name:
|
||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
||||
unet_state_dict[diffusers_name] = value
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
||||
elif "ff" in diffusers_name:
|
||||
unet_state_dict[diffusers_name] = value
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
||||
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
||||
unet_state_dict[diffusers_name] = value
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
||||
|
||||
elif lora_name.startswith("lora_te_"):
|
||||
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
if "self_attn" in diffusers_name:
|
||||
te_state_dict[diffusers_name] = value
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
||||
elif "mlp" in diffusers_name:
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
te_state_dict[diffusers_name] = value
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
||||
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
||||
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
||||
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
||||
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
||||
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
|
||||
|
||||
logger.info("Kohya-style checkpoint detected.")
|
||||
if len(unloaded_keys) > 0:
|
||||
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
|
||||
logger.warning(
|
||||
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
|
||||
# SDXL specificity.
|
||||
if "emb" in diffusers_name:
|
||||
pattern = r"\.\d+(?=\D*$)"
|
||||
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
|
||||
if ".in." in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
|
||||
if ".out." in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
|
||||
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("op", "conv")
|
||||
if "skip" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
|
||||
|
||||
if "transformer_blocks" in diffusers_name:
|
||||
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
||||
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
||||
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "ff" in diffusers_name:
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
else:
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
elif lora_name.startswith("lora_te_"):
|
||||
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
if "self_attn" in diffusers_name:
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "mlp" in diffusers_name:
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# (sayakpaul): Duplicate code. Needs to be cleaned.
|
||||
elif lora_name.startswith("lora_te1_"):
|
||||
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
|
||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
if "self_attn" in diffusers_name:
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "mlp" in diffusers_name:
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# (sayakpaul): Duplicate code. Needs to be cleaned.
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
|
||||
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
||||
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
||||
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
if "self_attn" in diffusers_name:
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
elif "mlp" in diffusers_name:
|
||||
# Be aware that this is the new diffusers convention and the rest of the code might
|
||||
# not utilize it yet.
|
||||
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# Rename the alphas so that they can be mapped appropriately.
|
||||
if lora_name_alpha in state_dict:
|
||||
alpha = state_dict.pop(lora_name_alpha).item()
|
||||
if lora_name_alpha.startswith("lora_unet_"):
|
||||
prefix = "unet."
|
||||
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
|
||||
prefix = "text_encoder."
|
||||
else:
|
||||
prefix = "text_encoder_2."
|
||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||
network_alphas.update({new_name: alpha})
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(
|
||||
f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
|
||||
)
|
||||
|
||||
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
|
||||
logger.info("Kohya-style checkpoint detected.")
|
||||
unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||
te_state_dict = {
|
||||
f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
|
||||
}
|
||||
te2_state_dict = (
|
||||
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
|
||||
if len(te2_state_dict) > 0
|
||||
else None
|
||||
)
|
||||
if te2_state_dict is not None:
|
||||
te_state_dict.update(te2_state_dict)
|
||||
|
||||
new_state_dict = {**unet_state_dict, **te_state_dict}
|
||||
return new_state_dict, network_alpha
|
||||
return new_state_dict, network_alphas
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
|
||||
@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
|
||||
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.rank = rank
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
q_rank = kwargs.pop("q_rank", None)
|
||||
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
||||
q_rank = q_rank if q_rank is not None else rank
|
||||
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
||||
|
||||
v_rank = kwargs.pop("v_rank", None)
|
||||
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
||||
v_rank = v_rank if v_rank is not None else rank
|
||||
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
||||
|
||||
out_rank = kwargs.pop("out_rank", None)
|
||||
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
||||
out_rank = out_rank if out_rank is not None else rank
|
||||
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(
|
||||
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
||||
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim,
|
||||
rank=4,
|
||||
attention_op: Optional[Callable] = None,
|
||||
network_alpha=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
self.rank = rank
|
||||
self.attention_op = attention_op
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
q_rank = kwargs.pop("q_rank", None)
|
||||
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
||||
q_rank = q_rank if q_rank is not None else rank
|
||||
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
||||
|
||||
v_rank = kwargs.pop("v_rank", None)
|
||||
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
||||
v_rank = v_rank if v_rank is not None else rank
|
||||
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
||||
|
||||
out_rank = kwargs.pop("out_rank", None)
|
||||
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
||||
out_rank = out_rank if out_rank is not None else rank
|
||||
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(
|
||||
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
||||
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
|
||||
super().__init__()
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.rank = rank
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
q_rank = kwargs.pop("q_rank", None)
|
||||
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
||||
q_rank = q_rank if q_rank is not None else rank
|
||||
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
||||
|
||||
v_rank = kwargs.pop("v_rank", None)
|
||||
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
||||
v_rank = v_rank if v_rank is not None else rank
|
||||
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
||||
|
||||
out_rank = kwargs.pop("out_rank", None)
|
||||
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
||||
out_rank = out_rank if out_rank is not None else rank
|
||||
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
||||
|
||||
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
||||
residual = hidden_states
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -48,14 +49,19 @@ class LoRALinearLayer(nn.Module):
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
||||
def __init__(
|
||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if rank > min(in_features, out_features):
|
||||
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
||||
|
||||
self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
|
||||
self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
|
||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
||||
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
@@ -91,7 +97,9 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
|
||||
def forward(self, x):
|
||||
if self.lora_layer is None:
|
||||
return super().forward(x)
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return super().forward(x) + self.lora_layer(x)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import torch.nn.functional as F
|
||||
from .activations import get_activation
|
||||
from .attention import AdaGroupNorm
|
||||
from .attention_processor import SpatialNorm
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
||||
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
|
||||
else:
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
|
||||
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
||||
self.time_emb_proj = None
|
||||
else:
|
||||
@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
self.conv_shortcut = LoRACompatibleConv(
|
||||
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import PatchEmbed
|
||||
from .lora import LoRACompatibleConv
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_flax_available,
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
@@ -51,6 +50,7 @@ else:
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
@@ -108,6 +108,12 @@ else:
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline
|
||||
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
@@ -121,20 +127,6 @@ else:
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .controlnet import StableDiffusionXLControlNetPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -1,21 +1,11 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_flax_available,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -26,6 +16,7 @@ else:
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
|
||||
@@ -22,6 +22,8 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
@@ -42,7 +44,11 @@ from ...utils import (
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
|
||||
|
||||
@@ -109,6 +115,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
controlnet: ControlNetModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -130,7 +137,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
if add_watermarker:
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
@@ -995,7 +1008,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
|
||||
@@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
_optional_components = []
|
||||
_exclude_from_cpu_offload = []
|
||||
_load_connected_pipes = False
|
||||
_is_onnx = False
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
@@ -839,6 +840,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
use_onnx (`bool`, *optional*, defaults to `None`):
|
||||
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
@@ -1268,6 +1274,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
use_onnx (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1293,6 +1308,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
custom_revision = kwargs.pop("custom_revision", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
@@ -1364,7 +1380,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
pretrained_model_name, use_auth_token, variant, revision, model_filenames
|
||||
)
|
||||
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames}
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
@@ -1411,6 +1427,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
||||
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
||||
if (
|
||||
@@ -1423,6 +1443,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
||||
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
||||
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
||||
@@ -1474,11 +1498,25 @@ class DiffusionPipeline(ConfigMixin):
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if pipeline_class._load_connected_pipes:
|
||||
# retrieve pipeline class from local file
|
||||
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
|
||||
pipeline_class = getattr(diffusers, cls_name, None)
|
||||
|
||||
if pipeline_class is not None and pipeline_class._load_connected_pipes:
|
||||
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
||||
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
|
||||
for connected_pipe_repo_id in connected_pipes:
|
||||
DiffusionPipeline.download(connected_pipe_repo_id)
|
||||
download_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"resume_download": resume_download,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"use_auth_token": use_auth_token,
|
||||
"variant": variant,
|
||||
"use_safetensors": use_safetensors,
|
||||
}
|
||||
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
|
||||
|
||||
return cached_folder
|
||||
|
||||
|
||||
@@ -1186,6 +1186,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
@@ -1542,7 +1543,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = pipeline_class(
|
||||
pipe = StableDiffusionXLPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
|
||||
@@ -46,6 +46,8 @@ def preprocess(image):
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: OnnxRuntimeModel,
|
||||
|
||||
@@ -424,10 +424,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
|
||||
if isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,6 @@ import PIL
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
OptionalDependencyNotAvailable,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()):
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||||
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
|
||||
|
||||
@@ -32,13 +32,17 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -84,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
- *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -125,6 +129,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -142,7 +147,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
if add_watermarker:
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -839,7 +849,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
@@ -853,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
)
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alpha=network_alpha,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
@@ -870,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alpha=network_alpha,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
|
||||
@@ -33,13 +33,17 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -131,6 +135,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -148,7 +153,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
if add_watermarker:
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -906,15 +916,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
@@ -988,7 +1000,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
|
||||
@@ -30,10 +30,20 @@ from ...models.attention_processor import (
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -265,6 +275,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -282,7 +293,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
if add_watermarker:
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -1168,15 +1184,17 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
@@ -1264,6 +1282,10 @@ class StableDiffusionXLInpaintPipeline(
|
||||
else:
|
||||
return StableDiffusionXLPipelineOutput(images=latents)
|
||||
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
|
||||
@@ -34,12 +34,16 @@ from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionXLPipelineOutput
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from .watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -109,6 +113,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_aesthetics_score: bool = False,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -128,7 +133,12 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
|
||||
self.vae.config.force_upcast = True # force the VAE to be in float32 mode, as it overflows in float16
|
||||
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
if add_watermarker:
|
||||
self.watermark = StableDiffusionXLWatermarker()
|
||||
else:
|
||||
self.watermark = None
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
@@ -811,6 +821,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
original_prompt_embeds_len = len(prompt_embeds)
|
||||
original_add_text_embeds_len = len(add_text_embeds)
|
||||
@@ -819,6 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
|
||||
|
||||
# Make dimensions consistent
|
||||
@@ -828,7 +840,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device).to(torch.float32)
|
||||
add_text_embeds = add_text_embeds.to(device).to(torch.float32)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 11. Denoising loop
|
||||
self.unet = self.unet.to(torch.float32)
|
||||
@@ -906,7 +918,10 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
image = self.watermark.apply_watermark(image)
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
||||
from ...utils import is_invisible_watermark_available
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
||||
|
||||
# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
|
||||
class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
|
||||
class StableDiffusionXLPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "invisible_watermark"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
|
||||
@@ -827,6 +827,81 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -563,10 +563,10 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
|
||||
@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
|
||||
expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
|
||||
|
||||
self.assertTrue(np.allclose(images, expected, atol=1e-4))
|
||||
|
||||
|
||||
@@ -210,6 +210,68 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
|
||||
image = output.images
|
||||
assert image.shape[0] == 2
|
||||
|
||||
def test_stable_diffusion_upscale_prompt_embeds(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet_upscale
|
||||
low_res_scheduler = DDPMScheduler()
|
||||
scheduler = DDIMScheduler(prediction_type="v_prediction")
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionUpscalePipeline(
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_noise_level=350,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False)
|
||||
image_from_prompt_embeds = sd_pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
image=[low_res_image],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1]
|
||||
|
||||
expected_height_width = low_res_image.size[0] * 4
|
||||
assert image.shape == (1, expected_height_width, expected_height_width, 3)
|
||||
expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_upscale_fp16(self):
|
||||
"""Test that stable diffusion upscale works with fp16"""
|
||||
|
||||
@@ -100,10 +100,10 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
|
||||
@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
|
||||
cross_attention_dim=64 if not skip_first_text_encoder else 32,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
@@ -100,10 +100,10 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
"tokenizer": tokenizer if not skip_first_text_encoder else None,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"requires_aesthetics_score": True,
|
||||
}
|
||||
return components
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("requires_aesthetics_score")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel

 assert image.shape == (1, 32, 32, 3)

-expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
+expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])

 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel

 assert image.shape == (1, 32, 32, 3)

-expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
+expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])

 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 addition_embed_type="text_time",
 addition_time_embed_dim=8,
 transformer_layers_per_block=(1, 2),
-projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+projection_class_embeddings_input_dim=72, # 5 * 8 + 32
 cross_attention_dim=64 if not skip_first_text_encoder else 32,
 )
 scheduler = EulerDiscreteScheduler(
@@ -102,10 +102,10 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 projection_dim=32,
 )
 text_encoder = CLIPTextModel(text_encoder_config)
-tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

 text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
-tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

 components = {
 "unet": unet,
@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 "tokenizer": tokenizer if not skip_first_text_encoder else None,
 "text_encoder_2": text_encoder_2,
 "tokenizer_2": tokenizer_2,
+"requires_aesthetics_score": True,
 }
 return components

@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 }
 return inputs

+def test_components_function(self):
+init_components = self.get_dummy_components()
+init_components.pop("requires_aesthetics_score")
+pipe = self.pipeline_class(**init_components)
+
+self.assertTrue(hasattr(pipe, "components"))
+self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
 def test_stable_diffusion_xl_inpaint_euler(self):
 device = "cpu" # ensure determinism for the device-dependent torch.Generator
 components = self.get_dummy_components()
@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel

 assert image.shape == (1, 64, 64, 3)

-expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
+expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])

 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
 image = sd_pipe(**inputs).images
 image_slice = image[0, -3:, -3:, -1]

-print(torch.from_numpy(image_slice).flatten())
 assert image.shape == (1, 64, 64, 3)

-expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
+expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])

 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
 addition_embed_type="text_time",
 addition_time_embed_dim=8,
 transformer_layers_per_block=(1, 2),
-projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+projection_class_embeddings_input_dim=72, # 5 * 8 + 32
 cross_attention_dim=64,
 )

@@ -105,10 +105,10 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
 projection_dim=32,
 )
 text_encoder = CLIPTextModel(text_encoder_config)
-tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

 text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
-tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
+tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

 components = {
 "unet": unet,
@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
 "tokenizer": tokenizer,
 "text_encoder_2": text_encoder_2,
 "tokenizer_2": tokenizer_2,
-# "safety_checker": None,
-# "feature_extractor": None,
+"requires_aesthetics_score": True,
 }
 return components

@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
 }
 return inputs

+def test_components_function(self):
+init_components = self.get_dummy_components()
+init_components.pop("requires_aesthetics_score")
+pipe = self.pipeline_class(**init_components)
+
+self.assertTrue(hasattr(pipe, "components"))
+self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
 def test_inference_batch_single_identical(self):
 super().test_inference_batch_single_identical(expected_max_diff=3e-3)

@@ -310,6 +310,49 @@ class DownloadTests(unittest.TestCase):
 assert len([f for f in files if ".bin" in f]) == 8
 assert not any(".safetensors" in f for f in files)

+def test_download_no_openvino_by_default(self):
+with tempfile.TemporaryDirectory() as tmpdirname:
+tmpdirname = DiffusionPipeline.download(
+"hf-internal-testing/tiny-stable-diffusion-open-vino",
+cache_dir=tmpdirname,
+)
+
+all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+files = [item for sublist in all_root_files for item in sublist]
+
+# make sure that by default no openvino weights are downloaded
+assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+assert not any("openvino_" in f for f in files)
+
+def test_download_no_onnx_by_default(self):
+with tempfile.TemporaryDirectory() as tmpdirname:
+tmpdirname = DiffusionPipeline.download(
+"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
+cache_dir=tmpdirname,
+)
+
+all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+files = [item for sublist in all_root_files for item in sublist]
+
+# make sure that by default no onnx weights are downloaded
+assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
+
+with tempfile.TemporaryDirectory() as tmpdirname:
+tmpdirname = DiffusionPipeline.download(
+"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
+cache_dir=tmpdirname,
+use_onnx=True,
+)
+
+all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
+files = [item for sublist in all_root_files for item in sublist]
+
+# if `use_onnx` is specified make sure weights are downloaded
+assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
+assert any((f.endswith(".onnx")) for f in files)
+assert any((f.endswith(".pb")) for f in files)
+
 def test_download_no_safety_checker(self):
 prompt = "hello"
 pipe = StableDiffusionPipeline.from_pretrained(
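The two new download tests pin down the default filtering behaviour: OpenVINO (`openvino_*`) and ONNX (`*.onnx`, `*.pb`) weight files in a repo are skipped unless explicitly requested. A small usage sketch of the flag exercised above, reusing the repo id from the test itself:

```python
# By default only the json/bin/txt files are fetched; use_onnx=True additionally pulls
# the .onnx / .pb files that the second half of the test checks for.
from diffusers import DiffusionPipeline

DiffusionPipeline.download("hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline")
DiffusionPipeline.download(
    "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", use_onnx=True
)
```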
@@ -374,7 +417,7 @@ class DownloadTests(unittest.TestCase):
 response_mock.json.return_value = {}

 # Download this model to make sure it's in the cache.
-orig_pipe = StableDiffusionPipeline.from_pretrained(
+orig_pipe = DiffusionPipeline.from_pretrained(
 "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
 )
 orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
@@ -382,7 +425,7 @@ class DownloadTests(unittest.TestCase):
 # Under the mock environment we get a 500 error when trying to reach the model.
 with mock.patch("requests.request", return_value=response_mock):
 # Download this model to make sure it's in the cache.
-pipe = StableDiffusionPipeline.from_pretrained(
+pipe = DiffusionPipeline.from_pretrained(
 "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
 )
 comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
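Switching these calls from `StableDiffusionPipeline` to the generic `DiffusionPipeline` loader works because the concrete pipeline class is resolved from the repo's `model_index.json`, so the test no longer hard-codes the pipeline type. A minimal sketch of that behaviour, reusing the repo id from the test:

```python
# DiffusionPipeline instantiates the class declared in the repo's model_index.json,
# so the object below is still a StableDiffusionPipeline instance.
from diffusers import DiffusionPipeline, StableDiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
assert isinstance(pipe, StableDiffusionPipeline)
```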
@@ -392,6 +435,42 @@ class DownloadTests(unittest.TestCase):
 if p1.data.ne(p2.data).sum() > 0:
 assert False, "Parameters not the same!"

+def test_local_files_only_are_used_when_no_internet(self):
+# A mock response for an HTTP head request to emulate server down
+response_mock = mock.Mock()
+response_mock.status_code = 500
+response_mock.headers = {}
+response_mock.raise_for_status.side_effect = HTTPError
+response_mock.json.return_value = {}
+
+# first check that with local files only the pipeline can only be used if cached
+with self.assertRaises(FileNotFoundError):
+with tempfile.TemporaryDirectory() as tmpdirname:
+orig_pipe = DiffusionPipeline.from_pretrained(
+"hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True, cache_dir=tmpdirname
+)
+
+# now download
+orig_pipe = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-torch")
+
+# make sure it can be loaded with local_files_only
+orig_pipe = DiffusionPipeline.from_pretrained(
+"hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True
+)
+orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")}
+
+# Under the mock environment we get a 500 error when trying to connect to the internet.
+# Make sure it works local_files_only only works here!
+with mock.patch("requests.request", return_value=response_mock):
+# Download this model to make sure it's in the cache.
+pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
+comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")}
+
+for m1, m2 in zip(orig_comps.values(), comps.values()):
+for p1, p2 in zip(m1.parameters(), m2.parameters()):
+if p1.data.ne(p2.data).sum() > 0:
+assert False, "Parameters not the same!"
+
 def test_download_from_variant_folder(self):
 for safe_avail in [False, True]:
 import diffusers
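The added `test_local_files_only_are_used_when_no_internet` encodes the expected offline workflow: loading with `local_files_only=True` fails on an empty cache, succeeds once the repo has been downloaded, and a plain `from_pretrained` then still works while the Hub is unreachable. Condensed into user-facing steps, this is a sketch of the same flow rather than new API:

```python
# Warm the cache once while online, then load without touching the network.
from diffusers import DiffusionPipeline

DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-torch")
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True
)
```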
@@ -387,7 +387,7 @@ class PipelineTesterMixin:
 batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]

 # make last batch super long
-batched_inputs[name][-1] = 2000 * "very long"
+batched_inputs[name][-1] = 100 * "very long"
 # or else we have images
 else:
 batched_inputs[name] = batch_size * [value]
@@ -462,7 +462,7 @@ class PipelineTesterMixin:
 batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]

 # make last batch super long
-batched_inputs[name][-1] = 2000 * "very long"
+batched_inputs[name][-1] = 100 * "very long"
 # or else we have images
 else:
 batched_inputs[name] = batch_size * [value]
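Both hunks shorten the stress prompt appended to the batched inputs from `2000 * "very long"` to `100 * "very long"`, which is still far longer than a typical text-encoder context and presumably keeps the fast tests cheap. A sketch of how these batched prompts are assembled in the mixin; the prompt and variable names below are illustrative:

```python
# Build a batch of progressively truncated prompts and make the last entry long enough
# to exercise the tokenizer's truncation path.
prompt = "A painting of a squirrel eating a burger"
batch_size = 3
len_prompt = len(prompt)
batched_prompt = [prompt[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_prompt[-1] = 100 * "very long"
```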