mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-12 22:05:25 +08:00
Compare commits
42 Commits
v0.10.2
...
improve_co
| Author | SHA1 | Date |
|---|---|---|
| | 7965655fd3 | |
| | f22326de59 | |
| | 8cecc66a74 | |
| | 35b66c8e32 | |
| | 013edb641a | |
| | 2595aa0c2f | |
| | 86ac3ea1d7 | |
| | ef3fcbb688 | |
| | 4725e488b9 | |
| | 4ab89f22fd | |
| | 7c823c2ed7 | |
| | 784beee969 | |
| | 8b7cb962a5 | |
| | e1bb8f6188 | |
| | e62dd5cfa8 | |
| | 07f95503e5 | |
| | e01d6cf295 | |
| | 244e16a7ab | |
| | b345c74d4d | |
| | b417042291 | |
| | 40c16ed2f0 | |
| | 69de9b2eaa | |
| | 3ce6380d3a | |
| | d2dc4de303 | |
| | ded3299d68 | |
| | 8bf5e59931 | |
| | 4645e28355 | |
| | 589330595d | |
| | 31444f5790 | |
| | c3b2f97534 | |
| | fc94c60c83 | |
| | ea64a7860a | |
| | 2868d99181 | |
| | 0c18d02cc9 | |
| | 6b68afd8e4 | |
| | 63c4944998 | |
| | 3ebe40fc5f | |
| | 089252542c | |
| | cd91fc06fe | |
| | ff65c2d72b | |
| | f1b726e46e | |
| | f242eba4fd | |
README.md (48 changed lines)
@@ -29,13 +29,13 @@ More precisely, 🤗 Diffusers offers:

### For PyTorch

**With `pip`**
**With `pip`** (official package)

```bash
pip install --upgrade diffusers[torch]
```

**With `conda`**
**With `conda`** (maintained by the community)

```sh
conda install -c conda-forge diffusers
```
@@ -79,19 +79,13 @@ In order to get started, we recommend taking a look at two notebooks:

Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) and [RunwayML](https://runwayml.com/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 4GB VRAM.
See the [model card](https://huggingface.co/CompVis/stable-diffusion) for more information.

You need to accept the model license before downloading or using the Stable Diffusion weights. Please, visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully and tick the checkbox if you agree. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.


### Text-to-Image generation with Stable Diffusion

First let's install
```bash
pip install --upgrade diffusers transformers scipy
```

Run this command to log in with your HF Hub token if you haven't before (you can skip this step if you prefer to run the model locally, follow [this](#running-the-model-locally) instead)
```bash
huggingface-cli login
pip install --upgrade diffusers transformers accelerate
```

We recommend using the model in [half-precision (`fp16`)](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/) as it gives almost always the same results as full
@@ -101,7 +95,7 @@ precision while being roughly twice as fast and requiring half the amount of GPU

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, revision="fp16")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
@@ -109,17 +103,16 @@ image = pipe(prompt).images[0]

#### Running the model locally
If you don't want to login to Hugging Face, you can also simply download the model folder
(after having [accepted the license](https://huggingface.co/runwayml/stable-diffusion-v1-5)) and pass
the path to the local folder to the `StableDiffusionPipeline`.

You can also simply download the model folder and pass the path to the local folder to the `StableDiffusionPipeline`.

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```

Assuming the folder is stored locally under `./stable-diffusion-v1-5`, you can also run stable diffusion
without requiring an authentication token:
Assuming the folder is stored locally under `./stable-diffusion-v1-5`, you can run stable diffusion
as follows:

pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
@@ -134,11 +127,7 @@ to using `fp16`.

The following snippet should result in less than 4GB VRAM.

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
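The hunk cuts off before the call that actually pushes memory below 4GB; a minimal sketch of the remaining lines, assuming the `pipe` object created above and the `enable_attention_slicing` API that appears elsewhere in this comparison:

```python
# Sketch, not part of the diff: chunk the attention computation to lower peak VRAM.
pipe.enable_attention_slicing()

image = pipe(prompt).images[0]
```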
@@ -164,7 +153,6 @@ If you want to run Stable Diffusion on CPU or you want to have maximum precision

please run the model in the default *full-precision* setting:

# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
@@ -262,11 +250,8 @@ from diffusers import StableDiffusionImg2ImgPipeline

# load the pipeline
device = "cuda"
model_id_or_path = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path,
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)

# or download via git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
# and pass `model_id_or_path="./stable-diffusion-v1-5"`.
pipe = pipe.to(device)
@@ -288,10 +273,7 @@ You can also run this example on colab [
, read the license carefully and tick the checkbox if you agree. Note that this is an additional license, you need to accept it even if you accepted the text-to-image Stable Diffusion license in the past. You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section](https://huggingface.co/docs/hub/security-tokens) of the documentation.

The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and a text prompt.

import PIL
@@ -311,11 +293,7 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
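The hunk stops before the call that combines prompt, image and mask; a minimal sketch of that step, assuming the `pipe`, `init_image` and `mask_image` objects defined above and the standard inpainting call signature:

```python
# Sketch, not part of the diff: fill the masked region of init_image according to the prompt.
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
```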
@@ -26,6 +26,8 @@
      title: "Text-Guided Image-to-Image"
    - local: using-diffusers/inpaint
      title: "Text-Guided Image-Inpainting"
    - local: using-diffusers/depth2img
      title: "Text-Guided Depth-to-Image"
    - local: using-diffusers/custom_pipeline_examples
      title: "Community Pipelines"
    - local: using-diffusers/contribute_pipeline
@@ -18,7 +18,7 @@ specific language governing permissions and limitations under the License.

# 🧨 Diffusers

🤗 Diffusers provides pretrained vision diffusion models, and serves as a modular toolbox for inference and training.
🤗 Diffusers provides pretrained vision and audio diffusion models, and serves as a modular toolbox for inference and training.

More precisely, 🤗 Diffusers offers:

@@ -127,7 +127,8 @@ Our library gathers telemetry information during `from_pretrained()` requests.
This data includes the version of Diffusers and PyTorch/Flax, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hub.
This usage data helps us debug issues and prioritize new features.
No private data, such as paths to models saved locally on disk, is ever collected.
Telemetry is only sent when loading models and pipelines from the HuggingFace Hub,
and is not collected during local usage.

We understand that not everyone wants to share additional information, and we respect your privacy,
so you can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
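The hunk ends before the shell command itself; as a minimal sketch, and assuming the variable is read on every `from_pretrained()` request, the same switch can also be flipped from Python before anything is loaded:

```python
import os

# Assumption: any truthy value such as "1" or "YES" disables telemetry.
os.environ["DISABLE_TELEMETRY"] = "1"

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # no usage data sent
```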
@@ -18,9 +18,12 @@ Whether you're a developer or an everyday user, this quick tour will help you ge

Before you begin, make sure you have all the necessary libraries installed:

```bash
pip install --upgrade diffusers
pip install --upgrade diffusers accelerate transformers
```

- [`accelerate`](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training
- [`transformers`](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion)

## DiffusionPipeline

The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks across different modalities. Take a look at the table below for some supported tasks:
@@ -29,19 +32,26 @@ The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion syst
|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
| Unconditional Image Generation | generate an image from gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
| Text-Guided Image-to-Image Translation | generate an image given an original image and a text prompt | [img2img](./using-diffusers/img2img) |
| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2image](./using-diffusers/depth2image) |

For more in-detail information on how diffusion pipelines function for the different tasks, please have a look at the [**Using Diffusers**](./using-diffusers/overview) section.

As an example, start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads).
In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256):
In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion).

For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), please carefully read its [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) before running the model.
This is due to the improved image generation capabilities of the model and the potentially harmful content that could be produced with it.
Please, head over to your stable diffusion model of choice, *e.g.* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5), and read the license.

You can load the model as follows:

```python
>>> from diffusers import DiffusionPipeline

>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
```

The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
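The generation step itself falls between the two hunks; a minimal sketch of it, assuming the pipeline loaded above, a CUDA device, and an illustrative prompt chosen to match the filename used below:

```python
>>> pipeline = pipeline.to("cuda")
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```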
@@ -66,40 +76,14 @@ You can save the image by simply calling:

>>> image.save("image_of_squirrel_painting.png")

More advanced models, like [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) require you to accept a [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) before running the model.
This is due to the improved image generation capabilities of the model and the potentially harmful content that could be produced with it.
Please, head over to your stable diffusion model of choice, *e.g.* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license carefully and tick the checkbox if you agree.
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
Having "click-accepted" the license, you can save your token:

```python
AUTH_TOKEN = "<please-fill-with-your-token>"
```

You can then load [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)
just like we did before only that now you need to pass your `AUTH_TOKEN`:

```python
>>> from diffusers import DiffusionPipeline

>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```

If you do not pass your authentication token you will see that the diffusion system will not be correctly
downloaded. Forcing the user to pass an authentication token ensures that it can be verified that the
user has indeed read and accepted the license, which also means that an internet connection is required.

**Note**: If you do not want to be forced to pass an authentication token, you can also simply download
the weights locally via:
**Note**: You can also use the pipeline locally by downloading the weights via:

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
```

and then load locally saved weights into the pipeline. This way, you do not need to pass an authentication
token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned stable-diffusion-v1-5 repo,
you can also load the pipeline as follows:
and then loading the saved weights into the pipeline.

>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
@@ -121,7 +105,7 @@ you could use it as follows:

>>> from diffusers import EulerDiscreteScheduler

>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

>>> # change scheduler to Euler
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

docs/source/using-diffusers/depth2img.mdx (new file, 35 lines)
@@ -0,0 +1,35 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-Guided Image-to-Image Generation

The [`StableDiffusionDepth2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images as well as a `depth_map` to preserve the images' structure. If no `depth_map` is provided, the pipeline will automatically predict the depth via an integrated depth-estimation model.

```python
import torch
import requests
from PIL import Image

from diffusers import StableDiffusionDepth2ImgPipeline

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth",
    torch_dtype=torch.float16,
).to("cuda")


url = "http://images.cocodataset.org/val2017/000000039769.jpg"
init_image = Image.open(requests.get(url, stream=True).raw)
prompt = "two tigers"
n_prompt = "bad, deformed, ugly, bad anatomy"
image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
```
@@ -23,7 +23,8 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |

@@ -774,3 +775,44 @@ Some examples along with the merge details:

3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5


### Stable Diffusion Comparisons

This Community Pipeline enables the comparison between the 4 checkpoints that exist for Stable Diffusion. They can be found through the following links:
1. [Stable Diffusion v1.1](https://huggingface.co/CompVis/stable-diffusion-v1-1)
2. [Stable Diffusion v1.2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
3. [Stable Diffusion v1.3](https://huggingface.co/CompVis/stable-diffusion-v1-3)
4. [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)

```python
from diffusers import DiffusionPipeline
import matplotlib.pyplot as plt

pipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', custom_pipeline='suvadityamuk/StableDiffusionComparison')
pipe.enable_attention_slicing()
pipe = pipe.to('cuda')
prompt = "an astronaut riding a horse on mars"
output = pipe(prompt)

plt.subplot(2, 2, 1)
plt.imshow(output.images[0])
plt.title('Stable Diffusion v1.1')
plt.axis('off')
plt.subplot(2, 2, 2)
plt.imshow(output.images[1])
plt.title('Stable Diffusion v1.2')
plt.axis('off')
plt.subplot(2, 2, 3)
plt.imshow(output.images[2])
plt.title('Stable Diffusion v1.3')
plt.axis('off')
plt.subplot(2, 2, 4)
plt.imshow(output.images[3])
plt.title('Stable Diffusion v1.4')
plt.axis('off')

plt.show()
```

As a result, you can look at a grid of all 4 generated images shown together, which captures the difference that the progression of training makes between the 4 checkpoints.
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
from diffusers.utils import deprecate, logging
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
except ImportError:
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
re_attention = re.compile(
|
||||
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
self.__init__additional__()
|
||||
|
||||
else:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.__init__additional__()
|
||||
|
||||
def __init__additional__(self):
|
||||
if not hasattr(self, "vae_scale_factor"):
|
||||
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if mask is not None:
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
if mask is not None:
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
# call the callback, if provided
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
|
||||
from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.onnx_utils import OnnxRuntimeModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
from diffusers.utils import deprecate, logging
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from diffusers.onnx_utils import ORT_TO_NP_TYPE
|
||||
except ImportError:
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
try:
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
except ImportError:
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
re_attention = re.compile(
|
||||
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
"""
|
||||
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
self.__init__additional__()
|
||||
|
||||
else:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
vae_decoder: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: OnnxRuntimeModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.__init__additional__()
|
||||
|
||||
def __init__additional__(self):
|
||||
self.unet_in_channels = 4
|
||||
self.vae_scale_factor = 8
|
||||
|
||||
@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
|
||||
latent_model_input = latent_model_input.numpy()
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
|
||||
latent_model_input = latent_model_input.numpy()
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=np.array([t], dtype=timestep_dtype),
|
||||
encoder_hidden_states=text_embeddings,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=np.array([t], dtype=timestep_dtype),
|
||||
encoder_hidden_states=text_embeddings,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
||||
)
|
||||
latents = scheduler_output.prev_sample.numpy()
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
||||
)
|
||||
latents = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mask is not None:
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
torch.from_numpy(init_latents_orig),
|
||||
torch.from_numpy(noise),
|
||||
t,
|
||||
).numpy()
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
if mask is not None:
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
torch.from_numpy(init_latents_orig),
|
||||
torch.from_numpy(noise),
|
||||
t,
|
||||
).numpy()
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
|
||||
@@ -19,4 +19,6 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
        model_output = self.unet(image, timestep).sample
        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

        return scheduler_output
        result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)

        return result
examples/community/stable_diffusion_comparison.py (new file, 405 lines)
@@ -0,0 +1,405 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
pipe1_model_id = "CompVis/stable-diffusion-v1-1"
|
||||
pipe2_model_id = "CompVis/stable-diffusion-v1-2"
|
||||
pipe3_model_id = "CompVis/stable-diffusion-v1-3"
|
||||
pipe4_model_id = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
|
||||
class StableDiffusionComparisonPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for parallel comparison of Stable Diffusion v1-v4
|
||||
This pipeline inherits from DiffusionPipeline and depends on the use of an Auth Token for
|
||||
downloading pre-trained checkpoints from Hugging Face Hub.
|
||||
If using Hugging Face Hub, pass the Model ID for Stable Diffusion v1.4 as the previous 3 checkpoints will be loaded
|
||||
automatically.
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
|
||||
self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
|
||||
self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
|
||||
self.pipe4 = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)
|
||||
|
||||
@property
|
||||
def layers(self) -> Dict[str, Any]:
|
||||
return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img_sd1_1(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
return self.pipe1(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img_sd1_2(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
return self.pipe2(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img_sd1_3(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
return self.pipe3(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def text2img_sd1_4(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
return self.pipe4(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation. This function will generate 4 results as part
|
||||
of running all the 4 pipelines for SD1.1-1.4 together in a serial-processing, parallel-invocation fashion.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, optional, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, optional, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, optional, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, optional, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
eta (`float`, optional, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, optional):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, optional):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, optional, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, optional, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
# Checks if the height and width are divisible by 8 or not
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
# Get first result from Stable Diffusion Checkpoint v1.1
|
||||
res1 = self.text2img_sd1_1(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get first result from Stable Diffusion Checkpoint v1.2
|
||||
res2 = self.text2img_sd1_2(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get first result from Stable Diffusion Checkpoint v1.3
|
||||
res3 = self.text2img_sd1_3(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get first result from Stable Diffusion Checkpoint v1.4
|
||||
res4 = self.text2img_sd1_4(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result
|
||||
return StableDiffusionPipelineOutput([res1[0], res2[0], res3[0], res4[0]])
|
||||
@@ -1,6 +1,6 @@
accelerate
torchvision
transformers>=4.21.0
transformers>=4.25.1
ftfy
tensorboard
modelcards

@@ -1,8 +1,8 @@
transformers>=4.21.0
transformers>=4.25.1
flax
optax
torch
torchvision
ftfy
tensorboard
modelcards

@@ -3,6 +3,7 @@ import hashlib
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,6 +18,7 @@ from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
@@ -148,7 +150,24 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
@@ -237,10 +256,11 @@ def parse_args(input_args=None):
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
# logger is not available yet
|
||||
if args.class_data_dir is not None:
|
||||
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
|
||||
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
return args
|
||||
|
||||
@@ -488,6 +508,15 @@ def main(args):
        revision=args.revision,
    )

    if is_xformers_available():
        try:
            unet.enable_xformers_memory_efficient_attention()
        except Exception as e:
            logger.warning(
                "Could not enable memory efficient attention. Make sure xformers is installed"
                f" correctly and a GPU is available: {e}"
            )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)
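The same opt-in pattern can be used at inference time. A rough sketch, assuming the `xformers` package is installed and a CUDA GPU is available:

```python
# Rough sketch of the same xformers opt-in at inference time; assumes xformers
# is installed and a CUDA GPU is available.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

if is_xformers_available():
    try:
        pipe.unet.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print(f"Could not enable memory efficient attention: {e}")
```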
@@ -568,6 +597,7 @@ def main(args):
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
@@ -605,16 +635,41 @@ def main(args):
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    first_epoch = 0

    for epoch in range(args.num_train_epochs):
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = resume_global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -678,16 +733,11 @@ def main(args):
                progress_bar.update(1)
                global_step += 1

                if global_step % args.save_steps == 0:
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        pipeline = DiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet),
                            text_encoder=accelerator.unwrap_model(text_encoder),
                            revision=args.revision,
                        )
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        pipeline.save_pretrained(save_path)
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
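For reference, the `"latest"` resolution above boils down to sorting the `checkpoint-{step}` directories numerically and taking the last one. A small standalone sketch of that logic:

```python
# Standalone sketch of the "latest" checkpoint resolution used above; assumes
# checkpoints were written as {output_dir}/checkpoint-{global_step}.
import os
from typing import Optional


def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    if not dirs:
        return None
    dirs.sort(key=lambda d: int(d.split("-")[1]))  # numeric sort, so 1000 > 999
    return os.path.join(output_dir, dirs[-1])
```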
@@ -1,7 +1,7 @@
accelerate
torchvision
transformers>=4.21.0
transformers>=4.25.1
datasets
ftfy
tensorboard
modelcards
@@ -1,4 +1,4 @@
transformers>=4.21.0
transformers>=4.25.1
datasets
flax
optax
@@ -6,4 +6,4 @@ torch
torchvision
ftfy
tensorboard
modelcards
@@ -18,6 +18,7 @@ from datasets import load_dataset
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -364,6 +365,15 @@ def main():
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
if is_xformers_available():
|
||||
try:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not enable memory efficient attention. Make sure xformers is installed"
|
||||
f" correctly and a GPU is available: {e}"
|
||||
)
|
||||
|
||||
# Freeze vae and text_encoder
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.21.0
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
modelcards
|
||||
modelcards
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
transformers>=4.21.0
|
||||
transformers>=4.25.1
|
||||
flax
|
||||
optax
|
||||
torch
|
||||
torchvision
|
||||
ftfy
|
||||
tensorboard
|
||||
modelcards
|
||||
modelcards
|
||||
|
||||
@@ -20,6 +20,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
@@ -439,6 +440,15 @@ def main():
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
if is_xformers_available():
|
||||
try:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not enable memory efficient attention. Make sure xformers is installed"
|
||||
f" correctly and a GPU is available: {e}"
|
||||
)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
@@ -538,6 +548,9 @@ def main():
    progress_bar.set_description("Steps")
    global_step = 0

    # keep original embeddings as reference
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    for epoch in range(args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
@@ -575,20 +588,15 @@ def main():
                loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                if accelerator.num_processes > 1:
                    grads = text_encoder.module.get_input_embeddings().weight.grad
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token
                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                with torch.no_grad():
                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
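The change above replaces gradient zeroing with restoring the frozen embedding rows after each optimizer step. A minimal sketch of that idea in isolation; the function name is illustrative, not part of the diff:

```python
# Minimal sketch of the embedding-restore trick above: overwrite every row except
# the learned placeholder token after the optimizer step, instead of zeroing grads.
import torch


def restore_frozen_embeddings(embeddings: torch.nn.Embedding, orig_weights: torch.Tensor, placeholder_token_id: int):
    index_no_updates = torch.arange(embeddings.num_embeddings) != placeholder_token_id
    with torch.no_grad():
        embeddings.weight[index_no_updates] = orig_weights[index_no_updates]
```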
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -31,9 +32,192 @@ from tqdm.auto import tqdm
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
if not isinstance(arr, torch.Tensor):
|
||||
arr = torch.from_numpy(arr)
|
||||
res = arr[timesteps].float().to(timesteps.device)
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that HF Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="ddpm-model-64",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true")
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=64,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
|
||||
" process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--num_epochs", type=int, default=100)
|
||||
parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
|
||||
parser.add_argument(
|
||||
"--save_model_epochs", type=int, default=10, help="How often to save the model during training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="cosine",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
|
||||
)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--use_ema",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether to use Exponential Moving Average for the final model weights.",
|
||||
)
|
||||
parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
|
||||
parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
|
||||
parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
default="epsilon",
|
||||
choices=["epsilon", "sample"],
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
|
||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
@@ -77,7 +261,17 @@ def main(args):
        ),
    )
    model = ORTModule(model)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
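The `inspect.signature` guard above keeps the script compatible with scheduler versions that do not yet accept `prediction_type`. The pattern in isolation:

```python
# The backwards-compatibility pattern above in isolation: only pass
# `prediction_type` if the installed DDPMScheduler accepts it.
import inspect
from diffusers import DDPMScheduler

scheduler_kwargs = {"num_train_timesteps": 1000, "beta_schedule": "linear"}
if "prediction_type" in inspect.signature(DDPMScheduler.__init__).parameters:
    scheduler_kwargs["prediction_type"] = "epsilon"
noise_scheduler = DDPMScheduler(**scheduler_kwargs)
```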
@@ -101,7 +295,6 @@ def main(args):
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
use_auth_token=True if args.use_auth_token else None,
|
||||
split="train",
|
||||
)
|
||||
else:
|
||||
@@ -111,8 +304,12 @@ def main(args):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
@@ -127,7 +324,12 @@ def main(args):
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
|
||||
ema_model = EMAModel(
|
||||
accelerator.unwrap_model(model),
|
||||
inv_gamma=args.ema_inv_gamma,
|
||||
power=args.ema_power,
|
||||
max_value=args.ema_max_decay,
|
||||
)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -171,11 +373,26 @@ def main(args):

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
                loss = F.mse_loss(noise_pred, noise)
                model_output = model(noisy_images, timesteps, return_dict=True)[0]

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
                    snr_weights = alpha_t / (1 - alpha_t)
                    loss = snr_weights * F.mse_loss(
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
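For the `sample` (x0-prediction) branch, the per-example loss is weighted by the signal-to-noise ratio alpha_bar_t / (1 - alpha_bar_t). A standalone sketch, assuming `alphas_cumprod` is the scheduler's 1-D cumulative alpha-bar schedule indexed by timestep:

```python
# Standalone sketch of the SNR-weighted x0 loss above; `alphas_cumprod` is assumed
# to be the scheduler's 1-D cumulative alpha-bar tensor indexed by timestep.
import torch
import torch.nn.functional as F


def snr_weighted_sample_loss(model_output, clean_images, alphas_cumprod, timesteps):
    alpha_t = alphas_cumprod.to(model_output.device)[timesteps].float()
    alpha_t = alpha_t.view(-1, 1, 1, 1)
    snr_weights = alpha_t / (1 - alpha_t)
    loss = snr_weights * F.mse_loss(model_output, clean_images, reduction="none")
    return loss.mean()
```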
@@ -204,9 +421,13 @@ def main(args):
|
||||
scheduler=noise_scheduler,
|
||||
)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
generator = torch.Generator(device=pipeline.device).manual_seed(0)
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
|
||||
images = pipeline(
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
@@ -225,56 +446,5 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--dataset_name", type=str, default=None)
|
||||
parser.add_argument("--dataset_config_name", type=str, default=None)
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
|
||||
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true")
|
||||
parser.add_argument("--cache_dir", type=str, default=None)
|
||||
parser.add_argument("--resolution", type=int, default=64)
|
||||
parser.add_argument("--train_batch_size", type=int, default=16)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=16)
|
||||
parser.add_argument("--num_epochs", type=int, default=100)
|
||||
parser.add_argument("--save_images_epochs", type=int, default=10)
|
||||
parser.add_argument("--save_model_epochs", type=int, default=10)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_scheduler", type=str, default="cosine")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.95)
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
|
||||
parser.add_argument("--use_ema", action="store_true", default=True)
|
||||
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
||||
parser.add_argument("--ema_power", type=float, default=3 / 4)
|
||||
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
|
||||
parser.add_argument("--push_to_hub", action="store_true")
|
||||
parser.add_argument("--use_auth_token", action="store_true")
|
||||
parser.add_argument("--hub_token", type=str, default=None)
|
||||
parser.add_argument("--hub_model_id", type=str, default=None)
|
||||
parser.add_argument("--hub_private_repo", action="store_true")
|
||||
parser.add_argument("--logging_dir", type=str, default="logs")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
|
||||
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
@@ -187,7 +188,72 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
# =========================#
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
# pretty much a no-op
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "transformer.text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
or k.endswith(".self_attn.v_proj.weight")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.weight")]
|
||||
k_code = k[-len("q_proj.weight")]
|
||||
if k_pre not in capture_qkv_weight:
|
||||
capture_qkv_weight[k_pre] = [None, None, None]
|
||||
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.bias")
|
||||
or k.endswith(".self_attn.k_proj.bias")
|
||||
or k.endswith(".self_attn.v_proj.bias")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.bias")]
|
||||
k_code = k[-len("q_proj.bias")]
|
||||
if k_pre not in capture_qkv_bias:
|
||||
capture_qkv_bias[k_pre] = [None, None, None]
|
||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
|
||||
for k_pre, tensors in capture_qkv_weight.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_text_enc_state_dict(text_enc_dict):
|
||||
@@ -223,8 +289,18 @@ if __name__ == "__main__":
|
||||
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
||||
if is_v20_model:
|
||||
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||
else:
|
||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
@@ -101,15 +102,6 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
||||
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
||||
|
||||
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
||||
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
||||
|
||||
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
@@ -475,15 +467,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
def convert_ldm_vae_checkpoint(vae_state_dict, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
@@ -648,6 +633,30 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
||||
return text_model
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
||||
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
||||
]
|
||||
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
||||
|
||||
textenc_transformer_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "transformer.text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
||||
model = PaintByExampleImageEncoder(config)
|
||||
@@ -718,15 +727,39 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
||||
|
||||
# SKIP for now - need openclip -> HF conversion script here
|
||||
# keys = list(checkpoint.keys())
|
||||
#
|
||||
# text_model_dict = {}
|
||||
# for key in keys:
|
||||
# if key.startswith("cond_stage_model.model.transformer"):
|
||||
# text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
|
||||
#
|
||||
# text_model.load_state_dict(text_model_dict)
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
||||
|
||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||
|
||||
for key in keys:
|
||||
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
||||
continue
|
||||
if key in textenc_conversion_map:
|
||||
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||
if key.startswith("cond_stage_model.model.transformer."):
|
||||
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||
if new_key.endswith(".in_proj_weight"):
|
||||
new_key = new_key[: -len(".in_proj_weight")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
||||
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
||||
elif new_key.endswith(".in_proj_bias"):
|
||||
new_key = new_key[: -len(".in_proj_bias")]
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
||||
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
||||
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
||||
else:
|
||||
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
||||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
@@ -744,6 +777,12 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The path to a vae checkpoint. If left to `None` the vae will be extracted from `checkpoint_path`."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_in_channels",
|
||||
default=None,
|
||||
@@ -789,6 +828,15 @@ if __name__ == "__main__":
|
||||
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upcast_attn",
|
||||
default=False,
|
||||
type=bool,
|
||||
help=(
|
||||
"Whether the attention computation should always be upcasted. This is necessary when running stable"
|
||||
" diffusion 2.1."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -796,23 +844,40 @@ if __name__ == "__main__":
|
||||
prediction_type = args.prediction_type
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
global_step = checkpoint["global_step"]
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
upcast_attention = False
|
||||
if args.original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
if not os.path.isfile("v2-inference-v.yaml"):
|
||||
# model_type = "v2"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
" -O v2-inference-v.yaml"
|
||||
)
|
||||
args.original_config_file = "./v2-inference-v.yaml"
|
||||
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
else:
|
||||
# model_type = "v1"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
if not os.path.isfile("v1-inference.yaml"):
|
||||
# model_type = "v1"
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
" -O v1-inference.yaml"
|
||||
)
|
||||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
@@ -852,6 +917,9 @@ if __name__ == "__main__":
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
# make sure scheduler works correctly with DDIM
|
||||
scheduler.register_to_config(clip_sample=False)
|
||||
|
||||
if args.scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
@@ -873,6 +941,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
@@ -883,7 +952,19 @@ if __name__ == "__main__":
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if args.vae_checkpoint_path is not None:
|
||||
vae_state_dict = torch.load(args.vae_checkpoint_path)
|
||||
vae_state_dict = vae_state_dict["state_dict"]
|
||||
else:
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
setup.py
@@ -218,7 +218,7 @@ install_requires = [

setup(
    name="diffusers",
    version="0.10.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    version="0.11.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    description="Diffusers",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",

@@ -1,4 +1,4 @@
__version__ = "0.10.0"
__version__ = "0.11.0.dev0"

from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -18,18 +18,6 @@ from .utils import (
)


# Make sure `transformers` is up to date
if is_transformers_available():
    import transformers

    if is_transformers_version("<", "4.25.1"):
        raise ImportError(
            f"`diffusers` requires transformers >= 4.25.1 to function correctly, but {transformers.__version__} was"
            " found in your environment. You can upgrade it with pip: `pip install transformers --upgrade`"
        )
    else:
        pass

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
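Since the hard import-time check is dropped here in favor of a requirements pin, callers who still want a guarantee can keep an equivalent guard on their side. A minimal sketch:

```python
# Minimal sketch of an equivalent guard in user code, now that diffusers no longer
# raises at import time for old transformers versions.
import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.25.1"):
    raise ImportError("This example expects transformers >= 4.25.1; please upgrade.")
```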
@@ -16,26 +16,36 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from distutils.version import StrictVersion
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib import request
|
||||
|
||||
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
|
||||
|
||||
from . import __version__
|
||||
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
||||
|
||||
|
||||
COMMUNITY_PIPELINES_URL = (
|
||||
"https://raw.githubusercontent.com/huggingface/diffusers/main/examples/community/{pipeline}.py"
|
||||
"https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py"
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_diffusers_versions():
|
||||
url = "https://pypi.org/pypi/diffusers/json"
|
||||
releases = json.loads(request.urlopen(url).read())["releases"].keys()
|
||||
return sorted(releases, key=StrictVersion)
|
||||
|
||||
|
||||
def init_hf_modules():
|
||||
"""
|
||||
Creates the cache directory for modules with an init, and adds it to the Python path.
|
||||
@@ -251,8 +261,26 @@ def get_cached_module_file(
|
||||
resolved_module_file = module_file_or_url
|
||||
submodule = "local"
|
||||
elif pretrained_model_name_or_path.count("/") == 0:
|
||||
available_versions = get_diffusers_versions()
|
||||
# cut ".dev0"
|
||||
latest_version = "v" + ".".join(__version__.split(".")[:3])
|
||||
|
||||
# retrieve github version that matches
|
||||
if revision is None:
|
||||
revision = latest_version if latest_version in available_versions else "main"
|
||||
logger.info(f"Defaulting to latest_version: {revision}.")
|
||||
elif revision in available_versions:
|
||||
revision = f"v{revision}"
|
||||
elif revision == "main":
|
||||
revision = revision
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
|
||||
f" {', '.join(available_versions + ['main'])}."
|
||||
)
|
||||
|
||||
# community pipeline on GitHub
|
||||
github_url = COMMUNITY_PIPELINES_URL.format(pipeline=pretrained_model_name_or_path)
|
||||
github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path)
|
||||
try:
|
||||
resolved_module_file = cached_download(
|
||||
github_url,
|
||||
|
||||
@@ -20,7 +20,6 @@ from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfFolder, whoami
|
||||
|
||||
from . import __version__
|
||||
@@ -56,7 +55,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
Formats a user-agent string with basic info about a request.
|
||||
"""
|
||||
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
|
||||
if DISABLE_TELEMETRY:
|
||||
if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
|
||||
return ua + "; telemetry/off"
|
||||
if is_torch_available():
|
||||
ua += f"; torch/{_torch_version}"
|
||||
@@ -75,27 +74,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
def send_telemetry(data: Dict, name: str):
|
||||
"""
|
||||
Sends logs to the Hub telemetry endpoint.
|
||||
|
||||
Args:
|
||||
data: the fields to track, e.g. {"example_name": "dreambooth"}
|
||||
name: a unique name to differentiate the telemetry logs, e.g. "diffusers_examples" or "diffusers_notebooks"
|
||||
"""
|
||||
if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
|
||||
pass
|
||||
|
||||
headers = {"user-agent": http_user_agent(data)}
|
||||
endpoint = HUGGINGFACE_CO_TELEMETRY + name
|
||||
try:
|
||||
r = requests.head(endpoint, headers=headers)
|
||||
r.raise_for_status()
|
||||
except Exception:
|
||||
# We don't want to error in case of connection errors of any kind.
|
||||
pass
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
|
||||
@@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__, is_torch_available
|
||||
from .hub_utils import send_telemetry
|
||||
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -340,10 +339,6 @@ class FlaxModelMixin:
|
||||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_path_with_subfolder}."
|
||||
)
|
||||
send_telemetry(
|
||||
{"model_class": cls.__name__, "model_path": "local", "framework": "flax"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
@@ -359,10 +354,6 @@ class FlaxModelMixin:
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
send_telemetry(
|
||||
{"model_class": cls.__name__, "model_path": "hub", "framework": "flax"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
|
||||
@@ -26,7 +26,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .hub_utils import send_telemetry
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
@@ -188,6 +187,39 @@ class ModelMixin(torch.nn.Module):
|
||||
if self._supports_gradient_checkpointing:
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
for module in self.children():
|
||||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@@ -561,10 +593,6 @@ class ModelMixin(torch.nn.Module):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
send_telemetry(
|
||||
{"model_class": cls.__name__, "model_path": "local", "framework": "pytorch"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
return model_file
|
||||
else:
|
||||
try:
|
||||
@@ -582,10 +610,6 @@ class ModelMixin(torch.nn.Module):
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
send_telemetry(
|
||||
{"model_class": cls.__name__, "model_path": "hub", "framework": "pytorch"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
return model_file
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -447,16 +446,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
# if xformers is installed try to use memory_efficient_attention by default
|
||||
if is_xformers_available():
|
||||
try:
|
||||
self.set_use_memory_efficient_attention_xformers(True)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Could not enable memory efficient attention. Make sure xformers is installed"
|
||||
f" correctly and a GPU is available: {e}"
|
||||
)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if not is_xformers_available():
|
||||
print("Here is how to install it")
|
||||
|
||||
@@ -218,6 +218,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
|
||||
@@ -29,7 +29,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .hub_utils import http_user_agent, send_telemetry
|
||||
from .hub_utils import http_user_agent
|
||||
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
||||
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
||||
@@ -346,16 +346,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
send_telemetry(
|
||||
{"pipeline_class": requested_pipeline_class, "pipeline_path": "hub", "framework": "flax"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
send_telemetry(
|
||||
{"pipeline_class": cls.__name__, "pipeline_path": "local", "framework": "flax"},
|
||||
name="diffusers_from_pretrained",
|
||||
)
|
||||
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from tqdm.auto import tqdm

from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent, send_telemetry
from .hub_utils import http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
@@ -375,6 +375,10 @@ class DiffusionPipeline(ConfigMixin):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
                custom pipeline from GitHub.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@@ -442,6 +446,7 @@ class DiffusionPipeline(ConfigMixin):
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)
        device_map = kwargs.pop("device_map", None)
@@ -504,16 +509,8 @@ class DiffusionPipeline(ConfigMixin):
                ignore_patterns=ignore_patterns,
                user_agent=user_agent,
            )
            send_telemetry(
                {"pipeline_class": requested_pipeline_class, "pipeline_path": "hub", "framework": "pytorch"},
                name="diffusers_from_pretrained",
            )
        else:
            cached_folder = pretrained_model_name_or_path
            send_telemetry(
                {"pipeline_class": cls.__name__, "pipeline_path": "local", "framework": "pytorch"},
                name="diffusers_from_pretrained",
            )

        config_dict = cls.load_config(cached_folder)

@@ -528,7 +525,9 @@ class DiffusionPipeline(ConfigMixin):
            else:
                file_name = CUSTOM_PIPELINE_FILE_NAME

            pipeline_class = get_class_from_dynamic_module(custom_pipeline, module_file=file_name, cache_dir=cache_dir)
            pipeline_class = get_class_from_dynamic_module(
                custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
            )
        elif cls != DiffusionPipeline:
            pipeline_class = cls
        else:
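With the new `custom_revision` argument, a community pipeline fetched from GitHub can be pinned to a released diffusers version instead of always tracking `main`. A hedged usage sketch; the community pipeline name is illustrative, not taken from this diff:

```python
# Hedged usage sketch of the new `custom_revision` argument; the community
# pipeline name is illustrative.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",  # assumed community pipeline name
    custom_revision="main",                  # or a released version such as "0.11.0"
)
```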
@@ -249,9 +249,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if not torch.equal(text_input_ids, untruncated_ids):
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
@@ -522,8 +522,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
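For reference, the guidance line in this hunk is a plain extrapolation from the unconditional prediction toward the text-conditioned one; a toy scalar check makes the effect of `guidance_scale` obvious:

```python
guidance_scale = 7.5
noise_pred_uncond, noise_pred_text = 0.10, 0.30  # toy scalar stand-ins for the two UNet outputs
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred)  # 1.60 -- guidance_scale = 1.0 would simply return noise_pred_text
```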
@@ -44,13 +44,24 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
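The rewritten `preprocess` now accepts a single PIL image, a list of PIL images, or an already-batched tensor, and always returns an `(N, 3, H, W)` float tensor scaled to `[-1, 1]` with `H`/`W` rounded down to multiples of 32. A usage sketch (file names are placeholders):

```python
import PIL.Image

imgs = [PIL.Image.open("a.png").convert("RGB"), PIL.Image.open("b.png").convert("RGB")]
batch = preprocess(imgs)        # shape (2, 3, H, W), values in [-1, 1]
single = preprocess(imgs[0])    # shape (1, 3, H, W)
unchanged = preprocess(batch)   # tensors are passed through untouched
```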
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
@@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

def __init__(
self,
@@ -246,9 +257,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -376,11 +387,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):

def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep + offset, 0)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps, num_inference_steps - t_start
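The new `get_timesteps` drops the `steps_offset` correction and derives the denoising window purely from `strength`. A quick numeric check of the replacement logic:

```python
num_inference_steps, strength = 50, 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# the pipeline then denoises over scheduler.timesteps[10:], i.e. the last 40 of the 50 steps
```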
@@ -422,7 +431,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -512,8 +521,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
)

# 4. Preprocess image
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
image = preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -46,7 +46,6 @@ class DDIMPipeline(DiffusionPipeline):
use_clipped_model_output: Optional[bool] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
@@ -83,7 +82,7 @@ class DDIMPipeline(DiffusionPipeline):
)
deprecate(
"generator.device == 'cpu'",
"0.11.0",
"0.12.0",
message,
)
generator = None

@@ -73,7 +73,7 @@ class DDPMPipeline(DiffusionPipeline):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)

if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
@@ -88,7 +88,7 @@ class DDPMPipeline(DiffusionPipeline):
)
deprecate(
"generator.device == 'cpu'",
"0.11.0",
"0.12.0",
message,
)
generator = None

@@ -66,7 +66,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
image: Union[torch.Tensor, PIL.Image.Image],
image: Union[torch.Tensor, PIL.Image.Image] = None,
batch_size: Optional[int] = 1,
num_inference_steps: Optional[int] = 100,
eta: Optional[float] = 0.0,

@@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
if isinstance(image, PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = [image]

image = image[None].transpose(0, 3, 1, 2)
image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

# preprocess mask
if isinstance(mask, PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = [mask]

mask = mask[None, None]
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0

# paint-by-example inverses the mask
mask = 1 - mask
@@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

def __init__(
self,
@@ -271,7 +273,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)

if height % 8 != 0 or width % 8 != 0:
@@ -322,8 +325,22 @@ class PaintByExamplePipeline(DiffusionPipeline):
masked_image_latents = 0.18215 * masked_image_latents

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
mask = mask.repeat(batch_size, 1, 1, 1)
masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
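Instead of unconditionally repeating `mask` and `masked_image_latents` `batch_size` times, the new code only expands them when they are smaller than the requested batch and requires the batch size to be an integer multiple. A condensed sketch of that rule as a hypothetical helper:

```python
def expand_to_batch(x, batch_size):
    # assumed helper; mirrors the check introduced in the hunk above
    if x.shape[0] < batch_size:
        if batch_size % x.shape[0] != 0:
            raise ValueError(f"cannot duplicate {x.shape[0]} items to a batch of {batch_size}")
        x = x.repeat(batch_size // x.shape[0], 1, 1, 1)
    return x
```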
@@ -350,7 +367,7 @@ class PaintByExamplePipeline(DiffusionPipeline):

if do_classifier_free_guidance:
uncond_embeddings = self.image_encoder.uncond_vector
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)

# For classifier free guidance, we need to do two forward passes.

@@ -46,7 +46,7 @@ if is_transformers_available() and is_torch_available():
from .safety_checker import StableDiffusionSafetyChecker

try:
if not (is_transformers_available() and is_torch_available()):
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline

@@ -35,14 +35,26 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
@@ -279,9 +291,9 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -414,11 +426,9 @@ class CycleDiffusionPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep + offset, 0)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps, num_inference_steps - t_start
@@ -462,7 +472,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -553,8 +563,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
)

# 4. Preprocess image
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
image = preprocess(image)

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -32,13 +32,26 @@ from . import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor

_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

def __init__(
self,
@@ -229,7 +242,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[np.ndarray, PIL.Image.Image],
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -325,8 +338,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

if isinstance(image, PIL.Image.Image):
image = preprocess(image)
image = preprocess(image)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

@@ -228,8 +228,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image],
image: Union[np.ndarray, PIL.Image.Image] = None,
mask_image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,

@@ -248,9 +248,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -521,8 +521,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -41,14 +41,26 @@ from ...utils import PIL_INTERPOLATION, deprecate, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@@ -189,9 +201,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -323,11 +335,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep + offset, 0)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps, num_inference_steps - t_start
@@ -368,12 +378,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):

def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
if isinstance(image, PIL.Image.Image):
width, height = image.size
width, height = map(lambda dim: dim - dim % 32, (width, height)) # resize to integer multiple of 32
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
width, height = image.size
image = [image]
else:
image = [img for img in image]

if isinstance(image[0], PIL.Image.Image):
width, height = image[0].size
else:
width, height = image[0].shape[-2:]

if depth_map is None:
@@ -423,7 +434,6 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -496,7 +506,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)

# 4. Prepare depth mask
# 4. Preprocess image
depth_mask = self.prepare_depth_map(
image,
depth_map,
@@ -506,11 +516,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
device,
)

# 5. Preprocess image
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
else:
image = 2.0 * (image / 255.0) - 1.0
# 5. Prepare depth mask
image = preprocess(image)

# 6. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

def __init__(
self,
@@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)

if height % 8 != 0 or width % 8 != 0:

@@ -43,13 +43,24 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -79,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -248,9 +259,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -381,11 +392,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):

def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep + offset, 0)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps, num_inference_steps - t_start
@@ -427,7 +436,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -517,8 +526,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
)

# 4. Preprocess image
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
image = preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -107,14 +107,29 @@ def prepare_mask_and_masked_image(image, mask):
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
else:
if isinstance(image, PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]

if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)

image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
if isinstance(mask, PIL.Image.Image):
mask = np.array(mask.convert("L"))

# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]

if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
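After concatenation the mask is scaled to `[0, 1]` and then binarized, so anti-aliased edge pixels are forced to either fully keep or fully inpaint a location. A toy check of the thresholding:

```python
import numpy as np

mask = np.array([[0.2, 0.5], [0.49, 0.8]], dtype=np.float32)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
print(mask)  # [[0. 1.] [0. 1.]]
```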
@@ -151,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker"]

def __init__(
self,
@@ -313,9 +328,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -481,8 +496,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
masked_image_latents = 0.18215 * masked_image_latents

# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
mask = mask.repeat(batch_size, 1, 1, 1)
masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (

@@ -92,7 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["feature_extractor"]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -261,9 +261,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -396,11 +396,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

t_start = max(num_inference_steps - init_timestep + offset, 0)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]

return timesteps, num_inference_steps - t_start
@@ -425,8 +423,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,

@@ -192,9 +192,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -325,8 +325,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -345,9 +345,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -393,6 +393,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
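With `height` and `width` defaulting to `None`, the resolution now falls back to the UNet configuration; for the standard Stable Diffusion v1 checkpoints this reproduces the previous 512-pixel default:

```python
sample_size = 64      # unet.config.sample_size for SD v1 checkpoints (assumed here)
vae_scale_factor = 8  # 2 ** (len(vae.config.block_out_channels) - 1)
height = None
height = height or sample_size * vae_scale_factor  # 512
```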
@@ -436,6 +439,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
# 6. Define model function
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)

noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)


@@ -32,15 +32,23 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def preprocess(image):
# resize to multiple of 64
width, height = image.size
width = width - width % 64
height = height - height % 64
image = image.resize((width, height))
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64

image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
@@ -156,9 +164,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -407,10 +415,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
)

# 4. Preprocess image
image = [image] if isinstance(image, PIL.Image.Image) else image
if isinstance(image, list):
image = [preprocess(img) for img in image]
image = torch.cat(image, dim=0)
image = preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)

# 5. set timesteps

@@ -7,7 +7,7 @@ from ...utils import (


try:
if not (is_transformers_available() and is_torch_available()):
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (

@@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds

if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
prompt = [p for p in prompt]

batch_size = len(prompt) if isinstance(prompt, list) else 1

# get prompt text embeddings
@@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
def check_inputs(self, image, height, width, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
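The scheduler hunks in this block all push the `predict_epsilon` removal from 0.11.0 to 0.12.0; the replacement, as the deprecation message itself suggests, is the `prediction_type` config value:

```python
from diffusers import DDPMScheduler

# replaces the deprecated DDPMScheduler(predict_epsilon=True); model_id is a placeholder
scheduler = DDPMScheduler.from_pretrained(model_id, prediction_type="epsilon")
```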
@@ -138,7 +138,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")


@@ -126,7 +126,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

@@ -256,7 +256,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"

@@ -132,7 +132,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

@@ -239,7 +239,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"

@@ -143,7 +143,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")


@@ -177,7 +177,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
predict_epsilon = deprecate("predict_epsilon", "0.12.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")


@@ -204,7 +204,7 @@ try:
if _torch_available:
import torch

if torch.__version__ < version.Version("1.12"):
if version.Version(torch.__version__) < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
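The xformers guard fix above matters because `torch.__version__` is a string: comparing it lexicographically gives wrong answers once minor versions reach two digits, and mixing a raw string with a `Version` object is not a reliable comparison either. Parsing both sides first restores the expected ordering:

```python
from packaging import version

print("1.9.0" < "1.12.0")                                    # False -- plain string comparison
print(version.Version("1.9.0") < version.Version("1.12.0"))  # True  -- proper version ordering
```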
@@ -354,7 +354,20 @@ def requires_backends(obj, backends):
if failed:
raise ImportError("".join(failed))

if name in ["StableDiffusionDepth2ImgPipeline"] and is_transformers_version("<", "4.26.0.dev0"):
if name in [
"VersatileDiffusionTextToImagePipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionDualGuidedPipeline",
"StableDiffusionImageVariationPipeline",
] and is_transformers_version("<", "4.25.0"):
raise ImportError(
f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
" --upgrade transformers \n```"
)

if name in [
"StableDiffusionDepth2ImgPipeline",
] and is_transformers_version("<", "4.26.0.dev0"):
raise ImportError(
f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install"
" git+https://github.com/huggingface/transformers \n```"

@@ -30,6 +30,7 @@ from diffusers.utils import (
torch_all_close,
torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from parameterized import parameterized

from ..test_modeling_common import ModelTesterMixin
@@ -255,6 +256,20 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

model.enable_xformers_memory_efficient_attention()

assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
), "xformers is not enabled"

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing

@@ -64,6 +64,7 @@ class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"batch_size": 1,
"generator": generator,
"num_inference_steps": 4,
}

@@ -52,6 +52,7 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"batch_size": 1,
"generator": generator,
"num_inference_steps": 2,
"output_type": "numpy",

@@ -67,7 +67,7 @@ class DDPMPipelineFastTests(unittest.TestCase):
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.11.0", "remove")
deprecate("remove this test", "0.12.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)


@@ -25,7 +25,7 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
from transformers import CLIPVisionConfig
from transformers import CLIPImageProcessor, CLIPVisionConfig

from ...test_pipelines_common import PipelineTesterMixin
@@ -76,6 +76,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
patch_size=4,
|
||||
)
|
||||
image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
|
||||
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
@@ -83,7 +84,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"vae": vae,
|
||||
"image_encoder": image_encoder,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"feature_extractor": feature_extractor,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -100,7 +101,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
|
||||
example_image = self.convert_to_pt(example_image)
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
@@ -29,7 +29,8 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -74,19 +75,22 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
|
||||
patch_size=4,
|
||||
)
|
||||
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
|
||||
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"image_encoder": image_encoder,
|
||||
"feature_extractor": feature_extractor,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
@@ -112,7 +116,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904])
+        expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_img_variation_multiple_images(self):
@@ -123,7 +127,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         sd_pipe.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = 2 * [inputs["image"]]
         output = sd_pipe(**inputs)

         image = output.images
@@ -131,7 +135,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         image_slice = image[-1, -3:, -3:, -1]

         assert image.shape == (2, 64, 64, 3)
-        expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281])
+        expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_img_variation_num_images_per_prompt(self):
@@ -150,7 +154,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         # test num_images_per_prompt=1 (default) for batch of images
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
         images = sd_pipe(**inputs).images

         assert images.shape == (batch_size, 64, 64, 3)
@@ -165,7 +169,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         # test num_images_per_prompt for batch of prompts
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
         images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images

         assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)

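The `expected_slice` updates above all follow the same regression pattern: compare a small corner of the generated image against hard-coded reference values with a max-absolute-difference tolerance. A self-contained sketch of that check (values here are synthetic, not taken from any pipeline):

```python
import numpy as np

image = np.random.RandomState(0).rand(1, 64, 64, 3)   # stand-in for pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]                   # bottom-right 3x3 patch, last channel

expected_slice = image_slice.flatten().round(4)        # the real tests hard-code these 9 numbers
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
```
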
@@ -30,7 +30,7 @@ from diffusers import (
 )
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin

@@ -77,6 +77,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
@@ -85,7 +86,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components

@@ -31,7 +31,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint impo
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin

@@ -78,6 +78,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
@@ -86,7 +87,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components

@@ -136,7 +136,9 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         return components

     def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
@@ -171,7 +173,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         output_loaded = pipe_loaded(**inputs)[0]

         max_diff = np.abs(output - output_loaded).max()
-        self.assertLess(max_diff, 3e-5)
+        self.assertLess(max_diff, 1e-4)

     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
@@ -243,7 +245,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(output_with_offload - output_without_offload).max()
-        self.assertLess(max_diff, 3e-5, "CPU offloading should not affect the inference results")
+        self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")

     @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_dict_tuple_outputs_equivalent(self):
@@ -260,7 +262,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]

         max_diff = np.abs(output - output_tuple).max()
-        self.assertLess(max_diff, 3e-5)
+        self.assertLess(max_diff, 1e-4)

     @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_num_inference_steps_consistent(self):
@@ -285,7 +287,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         if torch_device == "mps":
             expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
         else:
-            expected_slice = np.array([0.6907, 0.5135, 0.4688, 0.5169, 0.5738, 0.4600, 0.4435, 0.5640, 0.4653])
+            expected_slice = np.array([0.6854, 0.3740, 0.4857, 0.7130, 0.7403, 0.5536, 0.4829, 0.6182, 0.5053])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_depth2img_negative_prompt(self):
@@ -305,7 +307,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         if torch_device == "mps":
             expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
         else:
-            expected_slice = np.array([0.755, 0.521, 0.473, 0.554, 0.629, 0.442, 0.440, 0.582, 0.449])
+            expected_slice = np.array([0.6074, 0.3096, 0.4802, 0.7463, 0.7388, 0.5393, 0.4531, 0.5928, 0.4972])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_depth2img_multiple_init_images(self):
@@ -317,7 +319,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te

         inputs = self.get_dummy_inputs(device)
         inputs["prompt"] = [inputs["prompt"]] * 2
-        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = 2 * [inputs["image"]]
         image = sd_pipe(**inputs).images
         image_slice = image[-1, -3:, -3:, -1]

@@ -326,7 +328,7 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te
         if torch_device == "mps":
             expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
         else:
-            expected_slice = np.array([0.6475, 0.6302, 0.5627, 0.5222, 0.4318, 0.5489, 0.5079, 0.4419, 0.4494])
+            expected_slice = np.array([0.6681, 0.5023, 0.6611, 0.7605, 0.5724, 0.7959, 0.7240, 0.5871, 0.5383])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_depth2img_num_images_per_prompt(self):
@@ -374,7 +376,6 @@ class StableDiffusiondepth2imgPipelineFastTests(PipelineTesterMixin, unittest.Te

         inputs = self.get_dummy_inputs(device)

-        inputs["image"] = Image.fromarray(inputs["image"][0].permute(1, 2, 0).numpy().astype(np.uint8))
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]

@@ -452,7 +453,7 @@ class StableDiffusionDepth2ImgPipelineIntegrationTests(unittest.TestCase):
         image = output.images[0]

         assert image.shape == (480, 640, 3)
-        assert np.abs(expected_image - image).max() < 1e-3
+        assert np.abs(expected_image - image).max() < 5e-3

     def test_stable_diffusion_depth2img_pipeline_ddim(self):
         init_image = load_image(
@@ -540,8 +541,7 @@ class StableDiffusionDepth2ImgPipelineIntegrationTests(unittest.TestCase):
         torch.cuda.reset_peak_memory_stats()

         init_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/depth2img/sketch-mountains-input.jpg"
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png"
         )
         init_image = init_image.resize((768, 512))

@@ -565,7 +565,7 @@ class StableDiffusionDepth2ImgPipelineIntegrationTests(unittest.TestCase):
             guidance_scale=7.5,
             generator=generator,
             output_type="np",
-            num_inference_steps=5,
+            num_inference_steps=2,
         )

         mem_bytes = torch.cuda.max_memory_allocated()

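The sequential-offload integration test above bounds GPU memory via PyTorch's peak-allocation counters. A minimal sketch of that measurement pattern (requires a CUDA device; the matmul stands in for running the pipeline):

```python
import torch

torch.cuda.reset_peak_memory_stats()

x = torch.randn(1024, 1024, device="cuda")
y = x @ x  # placeholder workload

mem_bytes = torch.cuda.max_memory_allocated()
print(f"peak allocation: {mem_bytes / 2**20:.1f} MiB")
```
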
@@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeli
 from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, slow
 from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin

@@ -78,6 +78,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
@@ -86,7 +87,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components

@@ -203,7 +203,7 @@ class ConfigTester(unittest.TestCase):
         ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

         with CaptureLogger(logger) as cap_logger:
-            deprecate("remove this case", "0.11.0", "remove")
+            deprecate("remove this case", "0.12.0", "remove")
             ddpm_3 = DDPMScheduler.from_pretrained(
                 "hf-internal-testing/tiny-stable-diffusion-torch",
                 subfolder="scheduler",

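The `0.11.0` → `0.12.0` bumps here and in the scheduler tests below push back the removal deadline that `deprecate` enforces. A sketch of such a call (behaviour as I understand it: it only warns while the installed `diffusers` version is below the given version, and starts raising once that version is reached, reminding maintainers to delete the stale code path):

```python
from diffusers.utils import deprecate

# warns today; expected to raise once the installed diffusers version is >= 0.12.0
deprecate("remove this case", "0.12.0", "remove")
```
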
@@ -18,6 +18,7 @@ import json
 import os
+import random
 import shutil
 import sys
 import tempfile
 import unittest

@@ -207,6 +208,31 @@ class CustomPipelineTests(unittest.TestCase):
         # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
         assert pipeline.__class__.__name__ == "CustomPipeline"

+    def test_load_custom_github(self):
+        pipeline = DiffusionPipeline.from_pretrained(
+            "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main"
+        )
+
+        # make sure that on "main" pipeline gives only ones because of: https://github.com/huggingface/diffusers/pull/1690
+        with torch.no_grad():
+            output = pipeline()
+
+        assert output.numel() == output.sum()
+
+        # hack since Python doesn't like overwriting modules: https://stackoverflow.com/questions/3105801/unload-a-module-in-python
+        # Could in the future work with hashes instead.
+        del sys.modules["diffusers_modules.git.one_step_unet"]
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="0.10.2"
+        )
+        with torch.no_grad():
+            output = pipeline()
+
+        assert output.numel() != output.sum()
+
+        assert pipeline.__class__.__name__ == "UnetSchedulerOneForwardPipeline"
+
     def test_run_custom_pipeline(self):
         pipeline = DiffusionPipeline.from_pretrained(
             "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"

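The new `test_load_custom_github` pins the community pipeline to a Git revision. Outside the test suite the same call looks like this (network access and the `one_step_unet` community example are assumed):

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", custom_revision="main"
)
with torch.no_grad():
    output = pipeline()

# on "main" the one-step dummy pipeline returns a tensor of ones,
# so the element count equals the sum (the property the test asserts)
assert output.numel() == output.sum()
```
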
@@ -11,6 +11,7 @@ from typing import Callable, Union
 import numpy as np
 import torch

+import diffusers
 from diffusers import (
     CycleDiffusionPipeline,
     DanceDiffusionPipeline,
@@ -18,6 +19,7 @@ from diffusers import (
     StableDiffusionDepth2ImgPipeline,
     StableDiffusionImg2ImgPipeline,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_xformers_available
 from diffusers.utils.testing_utils import require_torch, torch_device

@@ -25,6 +27,9 @@ from diffusers.utils.testing_utils import require_torch, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False


+ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+
+
 @require_torch
 class PipelineTesterMixin:
     """
@@ -94,7 +99,80 @@ class PipelineTesterMixin:
         output_loaded = pipe_loaded(**inputs)[0]

         max_diff = np.abs(output - output_loaded).max()
-        self.assertLess(max_diff, 1e-5)
+        self.assertLess(max_diff, 1e-4)

+    def test_pipeline_call_implements_required_args(self):
+        assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        required_parameters.pop("self")
+        required_parameters = set(required_parameters)
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+
+        for param in required_parameters:
+            if param == "kwargs":
+                # kwargs can be added if arguments of pipeline call function are deprecated
+                continue
+            assert param in ALLOWED_REQUIRED_ARGS
+
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+
+        required_optional_params = ["generator", "num_inference_steps", "return_dict"]
+        for param in required_optional_params:
+            assert param in optional_parameters
+
+    def test_inference_batch_consistent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        for batch_size in [2, 4, 13]:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in ALLOWED_REQUIRED_ARGS:
+                    # prompt is string
+                    if name == "prompt":
+                        len_prompt = len(value)
+                        # make unequal batch sizes
+                        batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+                        # make last batch super long
+                        batched_inputs[name][-1] = 2000 * "very long"
+                    # or else we have images
+                    else:
+                        batched_inputs[name] = batch_size * [value]
+                elif name == "batch_size":
+                    batched_inputs[name] = batch_size
+                else:
+                    batched_inputs[name] = value
+
+            batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
+            batched_inputs["output_type"] = None
+
+            if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
+                batched_inputs.pop("output_type")
+
+            output = pipe(**batched_inputs)
+
+            assert len(output[0]) == batch_size
+
+            batched_inputs["output_type"] = "np"
+
+            if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
+                batched_inputs.pop("output_type")
+
+            output = pipe(**batched_inputs)[0]
+
+            assert output.shape[0] == batch_size
+
+        logger.setLevel(level=diffusers.logging.WARNING)
+
     def test_dict_tuple_outputs_equivalent(self):
         if torch_device == "mps" and self.pipeline_class in (
@@ -118,13 +196,7 @@ class PipelineTesterMixin:
         output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]

         max_diff = np.abs(output - output_tuple).max()
-        self.assertLess(max_diff, 1e-5)
-
-    def test_pipeline_call_implements_required_args(self):
-        required_args = ["num_inference_steps", "generator", "return_dict"]
-
-        for arg in required_args:
-            self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters)
+        self.assertLess(max_diff, 1e-4)

     def test_num_inference_steps_consistent(self):
         components = self.get_dummy_components()
@@ -138,7 +210,7 @@ class PipelineTesterMixin:

         outputs = []
         times = []
-        for num_steps in [3, 6, 9]:
+        for num_steps in [9, 6, 3]:
             inputs = self.get_dummy_inputs(torch_device)
             inputs["num_inference_steps"] = num_steps

@@ -152,7 +224,7 @@ class PipelineTesterMixin:
         # check that all outputs have the same shape
         self.assertTrue(all(outputs[0].shape == output.shape for output in outputs))
         # check that the inference time increases with the number of inference steps
-        self.assertTrue(all(times[i] > times[i - 1] for i in range(1, len(times))))
+        self.assertTrue(all(times[i] < times[i - 1] for i in range(1, len(times))))

     def test_components_function(self):
         init_components = self.get_dummy_components()
@@ -257,7 +329,7 @@ class PipelineTesterMixin:
         output_loaded = pipe_loaded(**inputs)[0]

         max_diff = np.abs(output - output_loaded).max()
-        self.assertLess(max_diff, 1e-5)
+        self.assertLess(max_diff, 1e-4)

     @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
     def test_to_device(self):
@@ -332,7 +404,7 @@ class PipelineTesterMixin:
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(output_with_offload - output_without_offload).max()
-        self.assertLess(max_diff, 1e-5, "CPU offloading should not affect the inference results")
+        self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")

     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
@@ -355,7 +427,7 @@ class PipelineTesterMixin:
         output_with_offload = pipe(**inputs)[0]

         max_diff = np.abs(output_with_offload - output_without_offload).max()
-        self.assertLess(max_diff, 1e-5, "XFormers attention should not affect the inference results")
+        self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")

     def test_progress_bar(self):
         components = self.get_dummy_components()

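The new `test_pipeline_call_implements_required_args` relies on plain `inspect` machinery. Here is the same check applied to a stand-alone function (a sketch; `example_call` is made up and only illustrates the required/optional split the mixin enforces):

```python
import inspect

ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]

def example_call(self, prompt, image=None, num_inference_steps=50, generator=None, return_dict=True):
    return prompt

parameters = inspect.signature(example_call).parameters
required = {k for k, v in parameters.items() if v.default is inspect.Parameter.empty} - {"self"}
optional = {k for k, v in parameters.items() if v.default is not inspect.Parameter.empty}

assert required <= set(ALLOWED_REQUIRED_ARGS)                           # only whitelisted required args
assert {"generator", "num_inference_steps", "return_dict"} <= optional  # required optional kwargs present
```
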
@@ -642,12 +642,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(prediction_type=prediction_type)

     def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.11.0", "remove")
+        deprecate("remove this test", "0.12.0", "remove")
         for predict_epsilon in [True, False]:
             self.check_over_configs(predict_epsilon=predict_epsilon)

     def test_deprecated_epsilon(self):
-        deprecate("remove this test", "0.11.0", "remove")
+        deprecate("remove this test", "0.12.0", "remove")
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()

@@ -626,12 +626,12 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
||||
def test_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.11.0", "remove")
|
||||
deprecate("remove this test", "0.12.0", "remove")
|
||||
for predict_epsilon in [True, False]:
|
||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
||||
|
||||
def test_deprecated_predict_epsilon_to_prediction_type(self):
|
||||
deprecate("remove this test", "0.11.0", "remove")
|
||||
deprecate("remove this test", "0.12.0", "remove")
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||