Compare commits


16 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| thomasw21 | c11d4b42a7 | Context has its own sequence length | 2022-11-22 12:26:59 +01:00 |
| thomasw21 | d5af4fd153 | Woops | 2022-11-22 12:21:59 +01:00 |
| thomasw21 | 1f135ac219 | Woops | 2022-11-22 12:20:26 +01:00 |
| thomasw21 | 0c49f4cf30 | Woops | 2022-11-22 12:19:15 +01:00 |
| thomasw21 | 5d4145cfa2 | Remove transpose for baddbmm | 2022-11-22 12:17:11 +01:00 |
| thomasw21 | 31d26872c1 | Revert "Making hidden_state contiguous before applying multiple linear layers" (reverts 1cd09cccf3) | 2022-11-22 12:07:54 +01:00 |
| thomasw21 | 1cd09cccf3 | Making hidden_state contiguous before applying multiple linear layers | 2022-11-22 11:55:03 +01:00 |
| thomasw21 | fa4d738cbb | Revert "Save one more copy" as it's much slower on A100 (reverts 136f84283c) | 2022-11-22 11:53:58 +01:00 |
| thomasw21 | 136f84283c | Save one more copy | 2022-11-22 11:49:15 +01:00 |
| thomasw21 | 42ba85998f | scatter_ argument is not called src, but rather value | 2022-11-22 01:11:18 +01:00 |
| thomasw21 | e1623e2081 | Woops | 2022-11-22 01:02:24 +01:00 |
| thomasw21 | fdef40ba03 | Woops | 2022-11-22 00:57:19 +01:00 |
| thomasw21 | fe691feb5a | Remove unused import | 2022-11-22 00:52:53 +01:00 |
| thomasw21 | f2ed5d8b44 | black | 2022-11-22 00:48:50 +01:00 |
| thomasw21 | e43244f33a | Fix transpose issue | 2022-11-22 00:47:22 +01:00 |
| thomasw21 | 3c45926a0e | WIP: some optimizations | 2022-11-22 00:37:17 +01:00 |
223 changed files with 3899 additions and 21967 deletions
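
Several of the commit messages above pin down concrete PyTorch behavior. The snippet below is a minimal sketch written for this summary, not code taken from the branch; the shapes and variable names are illustrative:

```python
import torch

batch, seq, ctx_len, dim = 2, 8, 24, 16

# 42ba85998f: when scattering a scalar, the keyword argument of Tensor.scatter_
# is `value`, not `src` (`src` must be a tensor, so src=1.0 raises a TypeError).
mask = torch.zeros(batch, seq)
index = torch.tensor([[1], [3]])
mask.scatter_(1, index, value=1.0)

# c11d4b42a7 / 5d4145cfa2: in cross-attention the context (keys/values) has its
# own sequence length, and torch.baddbmm fuses beta * input + alpha * (b1 @ b2)
# into a single kernel, e.g. for scaled attention scores.
query = torch.randn(batch, seq, dim)
context = torch.randn(batch, ctx_len, dim)
bias = torch.zeros(batch, seq, ctx_len)
scores = torch.baddbmm(bias, query, context.transpose(1, 2), beta=0, alpha=dim**-0.5)

# 1cd09cccf3 / 31d26872c1: .contiguous() materializes a transposed view into
# dense memory; the extra copy can help later linear layers but benchmarked
# slower on A100, hence the revert.
hidden_state = scores.transpose(1, 2).contiguous()
```

"Save one more copy" (136f84283c) met the same fate as the `contiguous()` experiment: fa4d738cbb reverted it because it was much slower on A100.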

View File

@@ -5,20 +5,7 @@ body:
- type: markdown
attributes:
value: |
Thanks a lot for taking the time to file this issue 🤗.
Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!
Thus, issues are of the same importance as pull requests when contributing to this library ❤️.
In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:
- 1. Please try to be as precise and concise as possible.
*Give your issue a fitting title. Assume that someone with very limited knowledge of diffusers can understand your issue. Add links to the source code, documentation, other issues, pull requests, etc...*
- 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
- 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
- type: markdown
attributes:
value: |
For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt)
Thanks for taking the time to fill out this bug report!
- type: textarea
id: bug-description
attributes:
@@ -33,8 +20,6 @@ body:
label: Reproduction
description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.
placeholder: Reproduction
validations:
required: true
- type: textarea
id: logs
attributes:

View File

@@ -1,66 +0,0 @@
name: Nightly integration tests
on:
schedule:
- cron: "0 0 * * *" # every day at midnight
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 1000
RUN_SLOW: yes
jobs:
run_slow_tests_apple_m1:
name: Slow PyTorch MPS tests on MacOS
runs-on: [ self-hosted, apple-m1 ]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Clean checkout
shell: arch -arch arm64 bash {0}
run: |
git clean -fxd
- name: Setup miniconda
uses: ./.github/actions/setup-miniconda
with:
python-version: 3.9
- name: Install dependencies
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python utils/print_env.py
- name: Run slow PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0}
env:
HF_HOME: /System/Volumes/Data/mnt/cache
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_mps_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_mps_test_reports
path: reports

View File

@@ -14,6 +14,7 @@ env:
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
MPS_TORCH_VERSION: 1.13.0
jobs:
run_fast_tests:
@@ -57,10 +58,8 @@ jobs:
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev -y
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -126,9 +125,8 @@ jobs:
run: |
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
shell: arch -arch arm64 bash {0}
@@ -137,9 +135,6 @@ jobs:
- name: Run fast PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0}
env:
HF_HOME: /System/Volumes/Data/mnt/cache
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/

View File

@@ -62,7 +62,6 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -132,7 +131,6 @@ jobs:
run: |
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -153,4 +151,4 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
path: reports
path: reports

.gitignore (vendored, 2 lines changed)
View File

@@ -165,4 +165,4 @@ tags
# DS_Store (MacOS)
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
*.mp4

View File

@@ -280,7 +280,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -34,7 +33,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -36,7 +35,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -34,7 +33,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -34,7 +33,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -33,7 +32,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -11,7 +11,6 @@ RUN apt update && \
git-lfs \
curl \
ca-certificates \
libsndfile1-dev \
python3.8 \
python3-pip \
python3.8-venv && \
@@ -33,7 +32,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
datasets \
hf-doc-builder \
huggingface-hub \
librosa \
modelcards \
numpy \
scipy \

View File

@@ -49,8 +49,6 @@
title: "OpenVINO"
- local: optimization/mps
title: "MPS"
- local: optimization/habana
title: "Habana Gaudi"
title: "Optimization/Special Hardware"
- sections:
- local: training/overview
@@ -102,30 +100,20 @@
title: "Latent Diffusion"
- local: api/pipelines/latent_diffusion_uncond
title: "Unconditional Latent Diffusion"
- local: api/pipelines/paint_by_example
title: "PaintByExample"
- local: api/pipelines/pndm
title: "PNDM"
- local: api/pipelines/score_sde_ve
title: "Score SDE VE"
- local: api/pipelines/stable_diffusion
title: "Stable Diffusion"
- local: api/pipelines/stable_diffusion_2
title: "Stable Diffusion 2"
- local: api/pipelines/stable_diffusion_safe
title: "Safe Stable Diffusion"
- local: api/pipelines/stochastic_karras_ve
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/versatile_diffusion
title: "Versatile Diffusion"
- local: api/pipelines/vq_diffusion
title: "VQ Diffusion"
- local: api/pipelines/repaint
title: "RePaint"
- local: api/pipelines/audio_diffusion
title: "Audio Diffusion"
title: "Pipelines"
- sections:
- local: api/experimental/rl

View File

@@ -51,7 +51,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
```
- *How to convert all use cases with multiple or single pipeline*
- *How to conver all use cases with multiple or single pipeline*
If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:

View File

@@ -1,102 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Audio Diffusion
## Overview
[Audio Diffusion](https://github.com/teticio/audio-diffusion) by Robert Dargavel Smith.
Audio Diffusion leverages the recent advances in image generation using diffusion models by converting audio samples to
and from mel spectrogram images.
The original codebase of this implementation can be found [here](https://github.com/teticio/audio-diffusion), including
training scripts and example notebooks.
## Available Pipelines:
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_audio_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py) | *Unconditional Audio Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb) |
## Examples:
### Audio Diffusion
```python
import torch
from IPython.display import Audio
from diffusers import DiffusionPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(device)
output = pipe()
display(output.images[0])
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
```
### Latent Audio Diffusion
```python
import torch
from IPython.display import Audio
from diffusers import DiffusionPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("teticio/latent-audio-diffusion-256").to(device)
output = pipe()
display(output.images[0])
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
```
### Audio Diffusion with DDIM (faster)
```python
import torch
from IPython.display import Audio
from diffusers import DiffusionPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-ddim-256").to(device)
output = pipe()
display(output.images[0])
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
```
### Variations, in-painting, out-painting etc.
```python
output = pipe(
raw_audio=output.audios[0, 0],
start_step=int(pipe.get_default_steps() / 2),
mask_start_secs=1,
mask_end_secs=1,
)
display(output.images[0])
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
```
## AudioDiffusionPipeline
[[autodoc]] AudioDiffusionPipeline
- __call__
- encode
- slerp
## Mel
[[autodoc]] Mel
- audio_slice_to_image
- image_to_audio

View File

@@ -57,7 +57,7 @@ prompt = "An astronaut riding an elephant"
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
image=init_image,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
@@ -83,7 +83,7 @@ torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
image=init_image,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,

View File

@@ -45,7 +45,6 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
@@ -53,21 +52,13 @@ available a colab notebook to directly try them out.
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
@@ -151,7 +142,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```

View File

@@ -1,73 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# PaintByExample
## Overview
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen
The abstract of the paper is the following:
*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.*
The original codebase can be found [here](https://github.com/Fantasy-Studio/Paint-by-Example).
## Available Pipelines:
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_paint_by_example.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py) | *Image-Guided Image Painting* | - |
## Tips
- PaintByExample is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint has been warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) with the objective of inpainting partly masked images conditioned on example / reference images
- To quickly demo *PaintByExample*, please have a look at [this demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example)
- You can run the following code snippet as an example:
```python
# !pip install diffusers transformers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png"
mask_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png"
example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
example_image = download_image(example_url).resize((512, 512))
pipe = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
image
```
## PaintByExamplePipeline
[[autodoc]] pipelines.paint_by_example.pipeline_paint_by_example.PaintByExamplePipeline
- __call__

View File

@@ -48,7 +48,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
```
### How to convert all use cases with multiple or single pipeline
### How to conver all use cases with multiple or single pipeline
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
@@ -76,48 +76,15 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionInpaintPipeline
[[autodoc]] StableDiffusionInpaintPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionDepth2ImgPipeline
[[autodoc]] StableDiffusionDepth2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionImageVariationPipeline
[[autodoc]] StableDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionUpscalePipeline
[[autodoc]] StableDiffusionUpscalePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

View File

@@ -1,174 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable diffusion 2
Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release).
The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).
*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.
These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAIONs NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*
For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release).
## Tips
### Available checkpoints:
Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation.
- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`]
- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`]
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
- *Image Upscaling (x4 resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) with [`StableDiffusionUpscalePipeline`]
- *Depth-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) with [`StableDiffusionDepth2ImagePipeline`]
We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is.
- *Text-to-Image (512x512 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2-base"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Text-to-Image (768x768 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Image Inpainting (512x512 resolution)*:
```python
import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
repo_id = "stabilityai/stable-diffusion-2-inpainting"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]
image.save("yellow_cat.png")
```
- *Image Upscaling (x4 resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) with [`StableDiffusionUpscalePipeline`]
```python
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))
prompt = "a white cat"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")
```
- *Depth-Guided Text-to-Image*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) [`StableDiffusionDepth2ImagePipeline`]
**Installation**
```bash
!pip install -U git+https://github.com/huggingface/transformers.git
!pip install diffusers[torch]
```
**Example**
```python
import torch
import requests
from PIL import Image
from diffusers import StableDiffusionDepth2ImgPipeline
pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth",
torch_dtype=torch.float16,
).to("cuda")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
init_image = Image.open(requests.get(url, stream=True).raw)
prompt = "two tigers"
n_propmt = "bad, deformed, ugly, bad anotomy"
image = pipe(prompt=prompt, image=init_image, negative_prompt=n_propmt, strength=0.7).images[0]
```
### How to load and use different schedulers.
The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler)
```

View File

@@ -1,90 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces content like this.
The abstract of the paper is the following:
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
## Tips
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
### Run Safe Stable Diffusion
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
### Interacting with the Safety Concept
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.safety_concept
```
For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
### Using pre-defined safety configurations
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`.
### How to load and use different schedulers.
The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
... )
```
## StableDiffusionSafePipelineOutput
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
## StableDiffusionPipelineSafe
[[autodoc]] StableDiffusionPipelineSafe
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -1,73 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VersatileDiffusion
VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi .
The abstract of the paper is the following:
*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.*
## Tips
- VersatileDiffusion is conceptually very similar to [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just an image data stream conditioned on text, VersatileDiffusion provides both an image and a text data stream and can be conditioned on both text and image.
### *Run VersatileDiffusion*
You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks
with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`]
**or**
You can run the individual pipelines, which are much more memory efficient (see the sketch after this list):
- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`]
- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`]
- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`]
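A minimal sketch of both routes, added here for illustration rather than taken from the removed page (the `shi-labs/versatile-diffusion` repo id comes from the scheduler example below; the prompt is made up):
```python
import torch
from diffusers import VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline

# All-in-one pipeline: every task goes through one memory-hungry object.
pipe = VersatileDiffusionPipeline.from_pretrained(
    "shi-labs/versatile-diffusion", torch_dtype=torch.float16
).to("cuda")
image = pipe.text_to_image("an astronaut riding a horse").images[0]

# Single-task pipeline: only the text-to-image components are loaded.
t2i = VersatileDiffusionTextToImagePipeline.from_pretrained(
    "shi-labs/versatile-diffusion", torch_dtype=torch.float16
).to("cuda")
image = t2i("an astronaut riding a horse").images[0]
```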
### *How to load and use different schedulers.*
The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler")
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler)
```
## VersatileDiffusionPipeline
[[autodoc]] VersatileDiffusionPipeline
## VersatileDiffusionTextToImagePipeline
[[autodoc]] VersatileDiffusionTextToImagePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionImageVariationPipeline
[[autodoc]] VersatileDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionDualGuidedPipeline
[[autodoc]] VersatileDiffusionDualGuidedPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -70,45 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
[[autodoc]] DDPMScheduler
#### Singlestep DPM-Solver
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
[[autodoc]] DPMSolverSinglestepScheduler
#### Multistep DPM-Solver
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
[[autodoc]] DPMSolverMultistepScheduler
#### Heun scheduler inspired by Karras et. al paper
Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] HeunDiscreteScheduler
#### DPM Discrete Scheduler inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2DiscreteScheduler
#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2AncestralDiscreteScheduler
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
@@ -119,6 +86,7 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
[[autodoc]] LMSDiscreteScheduler
#### Pseudo numerical methods for diffusion models (PNDM)

View File

@@ -35,7 +35,6 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb)
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
@@ -43,21 +42,13 @@ available a colab notebook to directly try them out.
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -120,24 +120,3 @@ git pull
```
Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
## Notice on telemetry logging
Our library gathers telemetry information during `from_pretrained()` requests.
This data includes the version of Diffusers and PyTorch/Flax, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hub.
This usage data helps us debug issues and prioritize new features.
No private data, such as paths to models saved locally on disk, is ever collected.
We understand that not everyone wants to share additional information, and we respect your privacy,
so you can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
On Linux/MacOS:
```bash
export DISABLE_TELEMETRY=YES
```
On Windows:
```bash
set DISABLE_TELEMETRY=YES
```

View File

@@ -117,34 +117,6 @@ image = pipe(prompt).images[0]
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
## Sliced VAE decode for larger batches
To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
```Python
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
## Offloading to CPU with accelerate for memory savings
For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.

View File

@@ -1,70 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# How to use Stable Diffusion on Habana Gaudi
🤗 Diffusers is compatible with Habana Gaudi through 🤗 [Optimum Habana](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion).
## Requirements
- Optimum Habana 1.3 or later, [here](https://huggingface.co/docs/optimum/habana/installation) is how to install it.
- SynapseAI 1.7.
## Inference Pipeline
To generate images with Stable Diffusion 1 and 2 on Gaudi, you need to instantiate two instances:
- A pipeline with [`GaudiStableDiffusionPipeline`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline). This pipeline supports *text-to-image generation*.
- A scheduler with [`GaudiDDIMScheduler`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiDDIMScheduler). This scheduler has been optimized for Habana Gaudi.
When initializing the pipeline, you have to specify `use_habana=True` to deploy it on HPUs.
Furthermore, in order to get the fastest possible generations you should enable **HPU graphs** with `use_hpu_graphs=True`.
Finally, you will need to specify a [Gaudi configuration](https://huggingface.co/docs/optimum/habana/package_reference/gaudi_config) which can be downloaded from the [Hugging Face Hub](https://huggingface.co/Habana).
```python
from optimum.habana import GaudiConfig
from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
model_name = "stabilityai/stable-diffusion-2-base"
scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
pipeline = GaudiStableDiffusionPipeline.from_pretrained(
model_name,
scheduler=scheduler,
use_habana=True,
use_hpu_graphs=True,
gaudi_config="Habana/stable-diffusion",
)
```
You can then call the pipeline to generate images by batches from one or several prompts:
```python
outputs = pipeline(
prompt=[
"High quality photo of an astronaut riding a horse in space",
"Face of a yellow cat, high resolution, sitting on a park bench",
],
num_images_per_prompt=10,
batch_size=4,
)
```
For more information, check out Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official Github repository.
## Benchmark
Here are the latencies for Habana Gaudi 1 and Gaudi 2 with the [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) Gaudi configuration (mixed precision bf16/fp32):
| | Latency | Batch size |
| ------- |:-------:|:----------:|
| Gaudi 1 | 4.37s | 4/8 |
| Gaudi 2 | 1.19s | 4/8 |

View File

@@ -12,5 +12,5 @@ specific language governing permissions and limitations under the License.
# Using Diffusers for audio
[`DanceDiffusionPipeline`] and [`AudioDiffusionPipeline`] can be used to generate
audio rapidly! More coming soon!
The [`DanceDiffusionPipeline`] can be used to generate audio rapidly!
More coming soon!
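For instance, a minimal sketch of unconditional audio generation with the [`DanceDiffusionPipeline`] (the `harmonai/maestro-150k` checkpoint and the 4-second length are illustrative assumptions):
```python
from diffusers import DanceDiffusionPipeline

# Checkpoint name is an assumption; any Dance Diffusion checkpoint should work
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to("cuda")

output = pipe(audio_length_in_s=4.0, num_inference_steps=100)
audio = output.audios[0]  # NumPy array of shape (channels, samples)
```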

View File

@@ -177,7 +177,7 @@ init_image = download_image(
prompt = "A fantasy landscape, trending on artstation"
images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe.img2img(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
### Inpainting
@@ -187,7 +187,7 @@ init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
prompt = "a cat sitting on a bench"
images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
images = pipe.inpaint(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
```
As shown above, this one pipeline can run "text-to-image", "image-to-image", and "inpainting" tasks in a single pipeline.

View File

@@ -37,7 +37,7 @@ init_image.thumbnail((768, 768))
prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```

View File

@@ -378,3 +378,21 @@ dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler"
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
```
## API
[[autodoc]] modeling_utils.ModelMixin
- from_pretrained
- save_pretrained
[[autodoc]] pipeline_utils.DiffusionPipeline
- from_pretrained
- save_pretrained
[[autodoc]] modeling_flax_utils.FlaxModelMixin
- from_pretrained
- save_pretrained
[[autodoc]] pipeline_flax_utils.FlaxDiffusionPipeline
- from_pretrained
- save_pretrained

View File

@@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License.
Diffusers is in the process of expanding to modalities other than images.
Example type | Colab | Pipeline |
:-------------------------:|:-------------------------:|:-------------------------:|
[Molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) | ❌
Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation.
* Generate conformations in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb)
More coming soon!

View File

@@ -13,13 +13,6 @@ specific language governing permissions and limitations under the License.
# Using Diffusers for reinforcement learning
Support for one RL model and related pipelines is included in the `experimental` source of diffusers.
More models and examples coming soon!
# Diffuser Value-guided Planning
You can run the model from [*Planning with Diffusion for Flexible Behavior Synthesis*](https://arxiv.org/abs/2205.09991) with Diffusers.
The script is located in the [RL Examples](https://github.com/huggingface/diffusers/tree/main/examples/rl) folder.
Or, run this example in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
[[autodoc]] diffusers.experimental.ValueGuidedRLPipeline
To try some of this in Colab, please look at the following example:
* Model-based reinforcement learning on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)

View File

@@ -23,7 +23,6 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
@@ -167,7 +166,7 @@ init_image = download_image("https://raw.githubusercontent.com/CompVis/stable-di
prompt = "A fantasy landscape, trending on artstation"
images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe.img2img(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
### Inpainting
@@ -177,7 +176,7 @@ init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
prompt = "a cat sitting on a bench"
images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
images = pipe.inpaint(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
```
As shown above, this one pipeline can run "text-to-image", "image-to-image", and "inpainting" tasks in a single pipeline.
@@ -412,7 +411,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="imagic_stable_diffusion",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
).to(device)
generator = torch.Generator("cuda").manual_seed(0)
generator = th.Generator("cuda").manual_seed(0)
seed = 0
prompt = "A photo of Barack Obama smiling with a big grin"
url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
@@ -421,16 +420,18 @@ init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
res = pipe.train(
prompt,
image=init_image,
init_image,
guidance_scale=7.5,
num_inference_steps=50,
generator=generator)
res = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)
res = pipe(alpha=1)
os.makedirs("imagic", exist_ok=True)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1.png')
res = pipe(alpha=1.5, guidance_scale=7.5, num_inference_steps=50)
res = pipe(alpha=1.5)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1_5.png')
res = pipe(alpha=2, guidance_scale=7.5, num_inference_steps=50)
res = pipe(alpha=2)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_2.png')
```
@@ -601,7 +602,7 @@ For example, this could be used to place a logo on a shirt and make it blend sea
import PIL
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionInpaintPipeline
image_path = "./path-to-image.png"
inner_image_path = "./path-to-inner-image.png"
@@ -611,11 +612,10 @@ init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
pipe = DiffusionPipeline.from_pretrained(
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
revision="fp16",
torch_dtype=torch.float16
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
@@ -623,8 +623,6 @@ prompt = "Your prompt here!"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
```
![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
### Text Based Inpainting Stable Diffusion
Use a text prompt to generate the mask for the area to be inpainted.
@@ -686,7 +684,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
pipe.set_scheduler("sample_heun")
pipe.set_sampler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
@@ -721,56 +719,10 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.set_scheduler("sample_euler")
pipe.set_sampler("sample_euler")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
```
![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)
### Checkpoint Merger Pipeline
Based on the AUTOMATIC1111/webui approach to checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the Hugging Face model_index.json format.
Checkpoint merging is currently memory intensive as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels, and
on Colab you might run out of the 12GB memory even while merging two checkpoints.
Usage:
```python
from diffusers import DiffusionPipeline
#Return a CheckpointMergerPipeline class that allows you to merge checkpoints.
#The checkpoint passed here is ignored. But still pass one of the checkpoints you plan to
#merge for convenience
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")
#There are multiple possible scenarios:
#The pipeline with the merged checkpoints is returned in all the scenarios
#Compatible checkpoints a.k.a. matched model_index.json files. Ignores the meta attributes in model_index.json during comparison (attrs prefixed with _).
merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffusion-v1-2"], interp = "sigmoid", alpha = 0.4)
#Incompatible checkpoints in model_index.json but merge might be possible. Use force = True to ignore model_index.json compatibility
merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion"], force = True, interp = "sigmoid", alpha = 0.4)
#Three checkpoint merging. Only "add_difference" method actually works on all three checkpoints. Using any other options will ignore the 3rd checkpoint.
merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4","hakurei/waifu-diffusion","prompthero/openjourney"], force = True, interp = "add_difference", alpha = 0.4)
prompt = "An astronaut riding a horse on Mars"
image = merged_pipe(prompt).images[0]
```
Some examples along with the merge details:
1. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" ; Sigmoid interpolation; alpha = 0.8
![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stability_v1_4_waifu_sig_0.8.png)
2. "hakurei/waifu-diffusion" + "prompthero/openjourney" ; Inverse Sigmoid interpolation; alpha = 0.8
![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/waifu_openjourney_inv_sig_0.8.png)
3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5
![Stable plus Waifu plus openjourney add_diff 0.5](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stable_waifu_openjourney_add_diff_0.5.png)

View File

@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type="epsilon",
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
@@ -174,12 +174,10 @@ def ddpm_bit_scheduler_step(
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if prediction_type == "epsilon":
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
pred_original_sample = model_output
# 3. Clip "predicted x_0"
scale = self.bit_scale

View File

@@ -1,262 +0,0 @@
import glob
import os
from typing import Dict, List, Union
import torch
from diffusers import DiffusionPipeline, __version__
from diffusers.pipeline_utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
WEIGHTS_NAME,
)
from huggingface_hub import snapshot_download
class CheckpointMergerPipeline(DiffusionPipeline):
"""
A class that supports merging diffusion models based on the discussion here:
https://github.com/huggingface/diffusers/issues/877
Example usage:-
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")
merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","prompthero/openjourney"], interp = 'inv_sigmoid', alpha = 0.8, force = True)
merged_pipe.to('cuda')
prompt = "An astronaut riding a unicycle on Mars"
results = merged_pipe(prompt)
## For more details, see the docstring for the merge method.
"""
def __init__(self):
super().__init__()
def _compare_model_configs(self, dict0, dict1):
if dict0 == dict1:
return True
else:
config0, meta_keys0 = self._remove_meta_keys(dict0)
config1, meta_keys1 = self._remove_meta_keys(dict1)
if config0 == config1:
print(f"Warning !: Mismatch in keys {meta_keys0} and {meta_keys1}.")
return True
return False
def _remove_meta_keys(self, config_dict: Dict):
meta_keys = []
temp_dict = config_dict.copy()
for key in config_dict.keys():
if key.startswith("_"):
temp_dict.pop(key)
meta_keys.append(key)
return (temp_dict, meta_keys)
@torch.no_grad()
def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
"""
Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
in the argument 'pretrained_model_name_or_path_list' as a list.
Parameters:
-----------
pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.
**kwargs:
Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None)
alpha = kwargs.pop("alpha", 0.5)
interp = kwargs.pop("interp", None)
print("Recieved list", pretrained_model_name_or_path_list)
checkpoint_count = len(pretrained_model_name_or_path_list)
# Ignore result from model_index.json comparison of the two checkpoints
force = kwargs.pop("force", False)
# If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
if checkpoint_count > 3 or checkpoint_count < 2:
raise ValueError(
"Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
" passed."
)
print("Received the right number of checkpoints")
# chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
# chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
# Validate that the checkpoints can be merged
# Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
config_dicts = []
for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = DiffusionPipeline.get_config_dict(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
config_dicts.append(config_dict)
comparison_result = True
for idx in range(1, len(config_dicts)):
comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
if not force and comparison_result is False:
raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
print(config_dicts[0], config_dicts[1])
print("Compatible model_index.json files found")
# Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
cached_folders = []
for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [
WEIGHTS_NAME,
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
ONNX_WEIGHTS_NAME,
DiffusionPipeline.config_name,
]
requested_pipeline_class = config_dict.get("_class_name")
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
)
print("Cached Folder", cached_folder)
cached_folders.append(cached_folder)
# Step 3:-
# Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
)
checkpoint_path_2 = None
if len(cached_folders) > 2:
checkpoint_path_2 = os.path.join(cached_folders[2])
if interp == "sigmoid":
theta_func = CheckpointMergerPipeline.sigmoid
elif interp == "inv_sigmoid":
theta_func = CheckpointMergerPipeline.inv_sigmoid
elif interp == "add_diff":
theta_func = CheckpointMergerPipeline.add_difference
else:
theta_func = CheckpointMergerPipeline.weighted_sum
# Find each module's state dict.
for attr in final_pipe.config.keys():
if not attr.startswith("_"):
checkpoint_path_1 = os.path.join(cached_folders[1], attr)
if os.path.exists(checkpoint_path_1):
files = glob.glob(os.path.join(checkpoint_path_1, "*.bin"))
checkpoint_path_1 = files[0] if len(files) > 0 else None
if checkpoint_path_2 is not None and os.path.exists(checkpoint_path_2):
files = glob.glob(os.path.join(checkpoint_path_2, "*.bin"))
checkpoint_path_2 = files[0] if len(files) > 0 else None
# For an attr if both checkpoint_path_1 and 2 are None, ignore.
# If at least one is present, deal with it according to the interp method, of course only if the state_dict keys match.
if checkpoint_path_1 is None and checkpoint_path_2 is None:
print("SKIPPING ATTR ", attr)
continue
try:
module = getattr(final_pipe, attr)
theta_0 = getattr(module, "state_dict")
theta_0 = theta_0()
update_theta_0 = getattr(module, "load_state_dict")
theta_1 = torch.load(checkpoint_path_1)
theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None
if not theta_0.keys() == theta_1.keys():
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
continue
if theta_2 and not theta_1.keys() == theta_2.keys():
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
except Exception:
print("SKIPPING ATTR ", attr)
continue
print("Found dicts for")
print(attr)
print(checkpoint_path_1)
print(checkpoint_path_2)
for key in theta_0.keys():
if theta_2:
theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
else:
theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
del theta_1
del theta_2
update_theta_0(theta_0)
del theta_0
print("Diffusion pipeline successfully updated with merged weights")
return final_pipe
@staticmethod
def weighted_sum(theta0, theta1, theta2, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
@staticmethod
def sigmoid(theta0, theta1, theta2, alpha):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
@staticmethod
def inv_sigmoid(theta0, theta1, theta2, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
@staticmethod
def add_difference(theta0, theta1, theta2, alpha):
return theta0 + (theta1 - theta2) * (1.0 - alpha)

View File

@@ -78,12 +78,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
)
self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
cut_out_size = (
feature_extractor.size
if isinstance(feature_extractor.size, int)
else feature_extractor.size["shortest_edge"]
)
self.make_cutouts = MakeCutouts(cut_out_size)
self.make_cutouts = MakeCutouts(feature_extractor.size)
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)

View File

@@ -17,7 +17,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from diffusers.utils import logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
@@ -133,7 +133,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
def train(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
height: Optional[int] = 512,
width: Optional[int] = 512,
generator: Optional[torch.Generator] = None,
@@ -184,10 +184,6 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image
accelerator = Accelerator(
gradient_accumulation_steps=1,
mixed_precision="fp16",
@@ -245,14 +241,14 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
lr=embedding_learning_rate,
)
if isinstance(image, PIL.Image.Image):
image = preprocess(image)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
latents_dtype = text_embeddings.dtype
image = image.to(device=self.device, dtype=latents_dtype)
init_latent_image_dist = self.vae.encode(image).latent_dist
image_latents = init_latent_image_dist.sample(generator=generator)
image_latents = 0.18215 * image_latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_image_dist = self.vae.encode(init_image).latent_dist
init_image_latents = init_latent_image_dist.sample(generator=generator)
init_image_latents = 0.18215 * init_image_latents
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
@@ -263,12 +259,12 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
for _ in range(text_embedding_optimization_steps):
with accelerator.accumulate(text_embeddings):
# Sample noise that we'll add to the latents
noise = torch.randn(image_latents.shape).to(image_latents.device)
timesteps = torch.randint(1000, (1,), device=image_latents.device)
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
@@ -305,12 +301,12 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
for _ in range(model_fine_tuning_optimization_steps):
with accelerator.accumulate(self.unet.parameters()):
# Sample noise that we'll add to the latents
noise = torch.randn(image_latents.shape).to(image_latents.device)
timesteps = torch.randint(1000, (1,), device=image_latents.device)
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

View File

@@ -110,7 +110,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -101,7 +101,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -6,13 +6,38 @@ import numpy as np
import torch
import PIL
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile(
@@ -121,7 +146,7 @@ def parse_prompt_attention(text):
return res
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
@@ -182,7 +207,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
def get_unweighted_text_embeddings(
pipe: StableDiffusionPipeline,
pipe: DiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
@@ -222,10 +247,10 @@ def get_unweighted_text_embeddings(
def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline,
pipe: DiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3,
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
@@ -239,14 +264,14 @@ def get_weighted_text_embeddings(
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
Args:
pipe (`StableDiffusionPipeline`):
pipe (`DiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
max_embeddings_multiples (`int`, *optional*, defaults to `1`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -362,11 +387,11 @@ def preprocess_image(image):
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -375,7 +400,7 @@ def preprocess_mask(mask, scale_factor=8):
return mask
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
weighting in the prompt.
@@ -410,12 +435,50 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
@@ -423,178 +486,76 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
)
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
):
def enable_xformers_memory_efficient_attention(self):
r"""
Encodes the prompt into text encoder hidden states.
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
if is_text2img:
return self.scheduler.timesteps.to(device), num_inference_steps
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
raise ImportError("Please install accelerate via `pip install accelerate`")
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].to(device)
return timesteps, num_inference_steps - t_start
device = self.device
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
if image is None:
shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents, None, None
else:
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
init_latents = torch.cat([init_latents] * batch_size, dim=0)
init_latents_orig = init_latents
shape = init_latents.shape
# add noise to latents using the timesteps
if device.type == "mps":
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(init_latents, noise, timestep)
return latents, init_latents_orig, noise
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: int = 512,
width: int = 512,
@@ -622,11 +583,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
image (`torch.FloatTensor` or `PIL.Image.Image`):
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -644,11 +605,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -687,115 +648,222 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
prompt = [prompt]
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, strength, callback_steps)
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
**kwargs,
)
dtype = text_embeddings.dtype
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
image = preprocess_image(image)
if image is not None:
image = image.to(device=self.device, dtype=dtype)
if isinstance(mask_image, PIL.Image.Image):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
if mask_image is not None:
mask = mask_image.to(device=self.device, dtype=dtype)
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
latents_dtype = text_embeddings.dtype
init_latents_orig = None
mask = None
noise = None
if init_image is None:
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_shape = (
batch_size * num_images_per_prompt,
self.unet.in_channels,
height // 8,
width // 8,
)
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
latents = torch.randn(
latents_shape,
generator=generator,
device="cpu",
dtype=latents_dtype,
).to(self.device)
else:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
dtype=latents_dtype,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
timesteps = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
else:
mask = None
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess_image(init_image)
# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# preprocess mask
if mask_image is not None:
if isinstance(mask_image, PIL.Image.Image):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents, init_latents_orig, noise = self.prepare_latents(
image,
latent_timestep,
batch_size * num_images_per_prompt,
height,
width,
dtype,
device,
generator,
latents,
)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# add noise to latents using the timesteps
if self.device.type == "mps":
# randn does not exist on mps
noise = torch.randn(
init_latents.shape,
generator=generator,
device="cpu",
dtype=latents_dtype,
).to(self.device)
else:
noise = torch.randn(
init_latents.shape,
generator=generator,
device=self.device,
dtype=latents_dtype,
)
latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# 9. Post-processing
image = self.decode_latents(latents)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image,
clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
)
else:
has_nsfw_concept = None
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return image, has_nsfw_concept
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -815,7 +883,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
@@ -863,9 +930,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
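As an illustration of the cancellation hook described above, a minimal usage sketch (the `pipe` variable and the prompt are placeholders, not part of the pipeline):

```python
import threading

cancel_event = threading.Event()

# `is_cancelled_callback` is polled every `callback_steps` steps;
# once it returns True, the pipeline call returns None
result = pipe(
    prompt="a photo of an astronaut riding a horse",
    is_cancelled_callback=cancel_event.is_set,
)
if result is None:
    print("generation was cancelled")
```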
@@ -891,14 +955,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type=output_type,
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
callback_steps=callback_steps,
**kwargs,
)
def img2img(
self,
image: Union[torch.FloatTensor, PIL.Image.Image],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8,
@@ -911,14 +974,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function for image-to-image generation.
Args:
image (`torch.FloatTensor` or `PIL.Image.Image`):
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
prompt (`str` or `List[str]`):
@@ -927,11 +989,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
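As a worked example of how `strength` modulates the schedule (this mirrors the `offset`/`init_timestep`/`t_start` arithmetic shown earlier in the denoising setup, assuming `steps_offset = 0`):

```python
num_inference_steps, strength, offset = 50, 0.8, 0
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep + offset, 0)  # 10
# the denoising loop runs over timesteps[t_start:], i.e. 40 of the 50 scheduled steps
```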
@@ -960,9 +1022,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
@@ -976,7 +1035,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
return self.__call__(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
init_image=init_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=strength,
@@ -987,14 +1046,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type=output_type,
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
callback_steps=callback_steps,
**kwargs,
)
def inpaint(
self,
image: Union[torch.FloatTensor, PIL.Image.Image],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -1008,18 +1066,17 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function for inpaint.
Args:
image (`torch.FloatTensor` or `PIL.Image.Image`):
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -1031,7 +1088,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
@@ -1061,9 +1118,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
@@ -1077,7 +1131,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
return self.__call__(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
init_image=init_image,
mask_image=mask_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
@@ -1089,7 +1143,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
output_type=output_type,
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
callback_steps=callback_steps,
**kwargs,
)


@@ -6,13 +6,35 @@ import numpy as np
import torch
import PIL
from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.onnx_utils import OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile(
@@ -240,7 +262,7 @@ def get_weighted_text_embeddings(
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
Args:
pipe (`OnnxStableDiffusionPipeline`):
pipe (`DiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
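A minimal sketch of the mean-preserving rescaling the docstring describes; the helper name and exact broadcasting are assumptions for illustration, not the pipeline's actual implementation:

```python
import numpy as np

def apply_weights_preserving_mean(embeddings: np.ndarray, weights: np.ndarray) -> np.ndarray:
    # embeddings: (batch, tokens, dim); weights: (batch, tokens)
    previous_mean = embeddings.mean()
    weighted = embeddings * weights[:, :, None]  # scale each token embedding by its weight
    return weighted * (previous_mean / weighted.mean())  # restore the original overall mean
```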
@@ -370,11 +392,11 @@ def preprocess_image(image):
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose with axes in order is a no-op)
@@ -382,7 +404,7 @@ def preprocess_mask(mask, scale_factor=8):
return mask
class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
weighting in prompt.
@@ -398,12 +420,12 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: SchedulerMixin,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
super().__init__()
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
@@ -412,177 +434,14 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
)
self.unet_in_channels = 4
self.vae_scale_factor = 8
def _encode_prompt(
self,
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
)
text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
if do_classifier_free_guidance:
uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
return text_embeddings
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def get_timesteps(self, num_inference_steps, strength, is_text2img):
if is_text2img:
return self.scheduler.timesteps, num_inference_steps
else:
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
def run_safety_checker(self, image):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# calling the safety_checker directly with batch size > 1 raises an error, so run it one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# the half-precision VAE decoder seems to produce wrong results when batch size > 1, so decode latents one at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
if image is None:
shape = (
batch_size,
self.unet_in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
return latents, None, None
else:
init_latents = self.vae_encoder(sample=image)[0]
init_latents = 0.18215 * init_latents
init_latents = np.concatenate([init_latents] * batch_size, axis=0)
init_latents_orig = init_latents
shape = init_latents.shape
# add noise to latents using the timesteps
noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
).numpy()
return latents, init_latents_orig, noise
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
init_image: Union[np.ndarray, PIL.Image.Image] = None,
mask_image: Union[np.ndarray, PIL.Image.Image] = None,
height: int = 512,
width: int = 512,
@@ -591,7 +450,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
strength: float = 0.8,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
@@ -610,11 +469,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
image (`np.ndarray` or `PIL.Image.Image`):
init_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
mask_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -632,19 +491,18 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -675,127 +533,205 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
image = init_image or image
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
prompt = [prompt]
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, strength, callback_steps)
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
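A tiny numeric illustration of the guidance formula applied later in the denoising loop (the numbers are made up):

```python
import numpy as np

noise_pred_uncond = np.array([0.10])
noise_pred_text = np.array([0.30])
guidance_scale = 7.5
# move guidance_scale times along the direction from unconditional to text-conditioned
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# noise_pred == array([1.6]); with guidance_scale = 1.0 this reduces to noise_pred_text
```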
# get unconditional embeddings for classifier free guidance
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
if generator is None:
generator = np.random
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
**kwargs,
)
dtype = text_embeddings.dtype
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
image = preprocess_image(image)
if image is not None:
image = image.astype(dtype)
if isinstance(mask_image, PIL.Image.Image):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
if mask_image is not None:
mask = mask_image.astype(dtype)
mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
else:
mask = None
text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
if do_classifier_free_guidance:
uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
# 5. set timesteps
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents, init_latents_orig, noise = self.prepare_latents(
image,
latent_timestep,
batch_size * num_images_per_prompt,
height,
width,
dtype,
generator,
latents,
)
latents_dtype = text_embeddings.dtype
init_latents_orig = None
mask = None
noise = None
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if init_image is None:
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // 8,
width // 8,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.numpy()
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=np.array([t], dtype=timestep_dtype),
encoder_hidden_states=text_embeddings,
timesteps = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
else:
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess_image(init_image)
# encode the init image into latents and scale the latents
init_image = init_image.astype(latents_dtype)
init_latents = self.vae_encoder(sample=init_image)[0]
init_latents = 0.18215 * init_latents
init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt)
init_latents_orig = init_latents
# preprocess mask
if mask_image is not None:
if isinstance(mask_image, PIL.Image.Image):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if mask.shape != init_latents.shape:
raise ValueError(f"The mask and init_image should be the same size, got {mask.shape} and {init_latents.shape}!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt)
# add noise to latents using the timesteps
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps
).numpy()
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=np.array([t]),
encoder_hidden_states=text_embeddings,
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.numpy()
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig),
torch.from_numpy(noise),
torch.tensor([t]),
).numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# the half-precision VAE decoder seems to misbehave when batch size > 1, so decode latents one at a time
image = []
for i in range(latents.shape[0]):
image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0])
image = np.concatenate(image)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# calling the safety_checker directly with batch size > 1 raises an error, so run it one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
noise_pred = noise_pred[0]
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig),
torch.from_numpy(noise),
t,
).numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return image, has_nsfw_concept
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -809,7 +745,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
@@ -844,9 +780,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -893,7 +828,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
def img2img(
self,
image: Union[np.ndarray, PIL.Image.Image],
init_image: Union[np.ndarray, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8,
@@ -901,7 +836,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
guidance_scale: Optional[float] = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
generator: Optional[np.random.RandomState] = None,
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -912,7 +847,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
r"""
Function for image-to-image generation.
Args:
image (`np.ndarray` or `PIL.Image.Image`):
init_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or ndarray representing an image batch, that will be used as the starting point for the
process.
prompt (`str` or `List[str]`):
@@ -921,11 +856,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
@@ -940,9 +875,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
output_type (`str`, *optional*, defaults to `"pil"`):
@@ -967,7 +901,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
return self.__call__(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
init_image=init_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=strength,
@@ -984,7 +918,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
def inpaint(
self,
image: Union[np.ndarray, PIL.Image.Image],
init_image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -993,7 +927,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
guidance_scale: Optional[float] = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
generator: Optional[np.random.RandomState] = None,
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -1004,11 +938,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
r"""
Function for inpaint.
Args:
image (`np.ndarray` or `PIL.Image.Image`):
init_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -1020,7 +954,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
@@ -1036,9 +970,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
output_type (`str`, *optional*, defaults to `"pil"`):
@@ -1063,7 +996,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
return self.__call__(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
init_image=init_image,
mask_image=mask_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,


@@ -113,7 +113,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"


@@ -13,7 +13,6 @@
# limitations under the License.
import importlib
import warnings
from typing import Callable, List, Optional, Union
import torch
@@ -22,7 +21,7 @@ from diffusers import LMSDiscreteScheduler
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.external import CompVisDenoiser
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -34,12 +33,7 @@ class ModelWrapper:
self.alphas_cumprod = alphas_cumprod
def apply_model(self, *args, **kwargs):
if len(args) == 3:
encoder_hidden_states = args[-1]
args = args[:2]
if kwargs.get("cond", None) is not None:
encoder_hidden_states = kwargs.pop("cond")
return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
return self.model(*args, **kwargs).sample
class StableDiffusionPipeline(DiffusionPipeline):
@@ -69,7 +63,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -84,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
super().__init__()
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -106,20 +99,31 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
self.k_diffusion_model = CompVisDenoiser(model)
def set_sampler(self, scheduler_type: str):
warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
return self.set_scheduler(scheduler_type)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
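A usage sketch for the scheduler selector above, assuming the sampler names exposed by `k_diffusion.sampling` (e.g. `sample_euler`, `sample_lms`, `sample_heun`):

```python
# the string is resolved via getattr(k_diffusion.sampling, scheduler_type)
pipe.set_scheduler("sample_euler")
```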
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
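A short usage sketch for the two toggles above (the `pipe` variable is illustrative, and xformers must be installed for the call to take effect):

```python
pipe.enable_xformers_memory_efficient_attention()
image = pipe("a photo of an astronaut riding a horse").images[0]
pipe.disable_xformers_memory_efficient_attention()  # back to the default attention
```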
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -431,7 +435,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(text_embeddings.dtype)
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
@@ -452,7 +455,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


@@ -42,7 +42,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
super().__init__()
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"


@@ -50,7 +50,6 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -61,7 +60,6 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
@@ -87,7 +85,6 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
def components(self) -> Dict[str, Any]:
@@ -124,7 +121,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
def inpaint(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
@@ -141,7 +138,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
# For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
return StableDiffusionInpaintPipelineLegacy(**self.components)(
prompt=prompt,
image=image,
init_image=init_image,
mask_image=mask_image,
strength=strength,
num_inference_steps=num_inference_steps,
@@ -159,7 +156,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
def img2img(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -176,7 +173,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
# For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
return StableDiffusionImg2ImgPipeline(**self.components)(
prompt=prompt,
image=image,
init_image=init_image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,


@@ -99,7 +99,7 @@ class TextInpainting(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -183,6 +183,24 @@ class TextInpainting(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@torch.no_grad()
def __call__(
self,


@@ -135,7 +135,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"


@@ -9,18 +9,8 @@ The `train_dreambooth.py` script shows how to implement the training procedure a
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the example folder and run
```bash
pip install -r requirements.txt
pip install -U -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
@@ -29,19 +19,6 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell, e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
### Dog toy example
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
@@ -62,8 +39,6 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
@@ -86,7 +61,7 @@ accelerate launch train_dreambooth.py \
### Training with prior-preservation loss
Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
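Conceptually, the class-image generation the script performs amounts to something like this sketch (the paths and counts mirror the flags below; it is an illustration, not the script's exact code):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
for i in range(200):  # --num_class_images
    image = pipe("a photo of dog").images[0]  # --class_prompt
    image.save(f"path-to-class-images/{i}.png")
```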
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
@@ -166,7 +141,7 @@ export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch --mixed_precision="fp16" train_dreambooth.py \
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
@@ -182,7 +157,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--mixed_precision=fp16
```
### Fine-tune text encoder with the UNet.
@@ -218,17 +194,6 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Using DreamBooth for other pipelines than Stable Diffusion
AltDiffusion also supports DreamBooth now; the running command is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
```
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
or
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
```
### Inference
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. sks in the above example) in your prompt.
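For example (a minimal sketch, assuming the model was saved to `path-to-save-model` and `sks` was used as the identifier):
```python
import torch
from diffusers import StableDiffusionPipeline

model_id = "path-to-save-model"  # the --output_dir used during training
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```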
@@ -327,97 +292,3 @@ python train_dreambooth_flax.py \
--num_class_images=200 \
--max_train_steps=800
```
### Training with prior-preservation loss
Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
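For the inpainting model, the script generates those class images by fully masking a dummy image and letting the pipeline paint a fresh sample; a rough sketch with placeholder values:
```python
# Rough sketch of class-image generation for the inpainting model: mask the
# whole (random) input so the pipeline inpaints an entire class image.
import torch
from diffusers import StableDiffusionInpaintPipeline
from torchvision import transforms
from PIL import Image

pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

fake_image = transforms.ToPILImage()(torch.rand((3, 512, 512)))
full_mask = Image.new("L", (512, 512), 255)  # 255 = masked everywhere
image = pipeline(prompt="a photo of dog", image=fake_image, mask_image=full_mask).images[0]
image.save("path-to-class-images/0.jpg")
```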
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
```
### Training with gradient checkpointing and 8-bit optimizer:
With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to train DreamBooth on a 16GB GPU.
To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
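Under the hood, `--use_8bit_adam` simply swaps the optimizer class; a minimal sketch with a stand-in module instead of the real UNet:
```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the UNet

try:
    import bitsandbytes as bnb

    optimizer_class = bnb.optim.AdamW8bit  # keeps optimizer state in 8 bit
except ImportError:
    optimizer_class = torch.optim.AdamW  # full-precision fallback

optimizer = optimizer_class(model.parameters(), lr=5e-6)
```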
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 --gradient_checkpointing \
--use_8bit_adam \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
```
### Fine-tune text encoder with the UNet.
The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB VRAM.___
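With `--train_text_encoder`, the script simply optimizes both parameter sets jointly, along these lines (a minimal sketch with stand-in modules):
```python
import itertools

import torch

unet = torch.nn.Linear(8, 8)  # stand-in for the UNet
text_encoder = torch.nn.Linear(8, 8)  # stand-in for the text encoder

params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
optimizer = torch.optim.AdamW(params_to_optimize, lr=2e-6)
```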
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_text_encoder \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--use_8bit_adam \
--gradient_checkpointing \
--learning_rate=2e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
```

View File

@@ -1,3 +1,4 @@
diffusers>=0.5.0
accelerate
torchvision
transformers>=4.21.0

View File

@@ -1,3 +1,4 @@
diffusers>=0.5.1
transformers>=4.21.0
flax
optax

View File

@@ -14,42 +14,18 @@ from torch.utils.data import Dataset
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from transformers import CLIPTextModel, CLIPTokenizer
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -111,8 +87,8 @@ def parse_args(input_args=None):
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(
@@ -148,7 +124,6 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
@@ -212,12 +187,12 @@ def parse_args(input_args=None):
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -304,10 +279,9 @@ class DreamBoothDataset(Dataset):
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
if self.class_data_root:
@@ -317,37 +291,14 @@ class DreamBoothDataset(Dataset):
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return example
def collate_fn(examples, with_prior_preservation=False):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.cat(input_ids, dim=0)
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -405,7 +356,7 @@ def main(args):
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
@@ -455,24 +406,19 @@ def main(args):
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = CLIPTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
use_fast=False,
)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load models and create wrapper for stable diffusion
text_encoder = text_encoder_cls.from_pretrained(
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
@@ -526,7 +472,7 @@ def main(args):
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
@@ -538,12 +484,34 @@ def main(args):
center_crop=args.center_crop,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=1,
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
# Scheduler and math around the number of training steps.
@@ -570,9 +538,9 @@ def main(args):
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encoder and vae to gpu.
@@ -635,31 +603,23 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -678,17 +638,6 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
@@ -700,7 +649,7 @@ def main(args):
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),

View File

@@ -23,7 +23,6 @@ from diffusers import (
FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
@@ -34,9 +33,6 @@ from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = logging.getLogger(__name__)
@@ -93,8 +89,8 @@ def parse_args():
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(

View File

@@ -1,14 +0,0 @@
# Research projects
This folder contains various research projects using 🧨 Diffusers.
They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
Updating them to the most recent version of the library will require some work.
To use any of them, just run the command
```
pip install -r requirements.txt
```
inside the folder of your choice.
If you need help with any of those, please open an issue where you directly ping the author(s), as indicated at the top of the README of each folder.

View File

@@ -1,26 +0,0 @@
# Dreambooth for the inpainting model
This script was added by @thedarkzeno.
Please note that this script is not actively maintained; however, you can open an issue and tag @thedarkzeno or @patil-suraj.
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400
```
The script is also compatible with prior preservation loss and gradient checkpointing.

View File

@@ -1,7 +0,0 @@
diffusers==0.9.0
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards

View File

@@ -1,747 +0,0 @@
import argparse
import hashlib
import itertools
import math
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image, ImageDraw
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
def prepare_mask_and_masked_image(image, mask):
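    """Convert a PIL image/mask pair to tensors for inpainting training: the
    image is scaled to [-1, 1], the mask is binarized at 0.5, and the masked
    image keeps only the pixels where the mask is 0 (the unmasked region)."""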
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
# generate random masks
def random_mask(im_shape, ratio=1, mask_full_image=False):
mask = Image.new("L", im_shape, 0)
draw = ImageDraw.Draw(mask)
size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
# use this to always mask the whole image
if mask_full_image:
size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
draw_type = random.randint(0, 1)
if draw_type == 0 or mask_full_image:
draw.rectangle(
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
fill=255,
)
else:
draw.ellipse(
(center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
fill=255,
)
return mask
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.instance_data_dir is None:
raise ValueError("You must specify a train data directory.")
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
return args
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms_resize_and_crop = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
]
)
self.image_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
instance_image = self.image_transforms_resize_and_crop(instance_image)
example["PIL_images"] = instance_image
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
class_image = self.image_transforms_resize_and_crop(class_image)
example["class_images"] = self.image_transforms(class_image)
example["class_PIL_images"] = class_image
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
return example
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
if args.seed is not None:
set_seed(args.seed)
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(
sample_dataset, batch_size=args.sample_batch_size, num_workers=1
)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
transform_to_pil = transforms.ToPILImage()
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
bsz = len(example["prompt"])
fake_images = torch.rand((3, args.resolution, args.resolution))
fake_pil_images = transform_to_pil(fake_images)
fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
prior_pil = [example["class_PIL_images"] for example in examples]
masks = []
masked_images = []
for example in examples:
pil_image = example["PIL_images"]
# generate a random mask
mask = random_mask(pil_image.size, 1, False)
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
masks.append(mask)
masked_images.append(masked_image)
if args.with_prior_preservation:
for pil_image in prior_pil:
# generate a random mask
mask = random_mask(pil_image.size, 1, False)
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
masks.append(mask)
masked_images.append(masked_image)
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
masks = torch.stack(masks)
masked_images = torch.stack(masked_images)
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encoder and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(args.num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Convert masked images to latent space
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * 0.18215
masks = batch["masks"]
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# concatenate the noised latents with the mask and the masked latents
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@@ -1,12 +1,9 @@
# Overview
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two ways to use the script `run_diffuser_locomotion.py`.
The key option is a change of the variable `n_guide_steps`.
When `n_guide_steps=0`, the trajectories are sampled from the diffusion model, but not fine-tuned to maximize reward in the environment.
By default, `n_guide_steps=2` to match the original implementation.
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two scripts,
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.
You will need some RL-specific requirements to run the examples:

View File

@@ -0,0 +1,57 @@
import d4rl # noqa
import gym
import tqdm
from diffusers.experimental import ValueGuidedRLPipeline
config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=0,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
t_grad_cutoff=2,
device="cpu",
)
if __name__ == "__main__":
env_name = "hopper-medium-v2"
env = gym.make(env_name)
pipeline = ValueGuidedRLPipeline.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32",
env=env,
)
env.seed(0)
obs = env.reset()
total_reward = 0
total_score = 0
T = 1000
rollout = [obs.copy()]
try:
for t in tqdm.tqdm(range(T)):
# Call the policy
denorm_actions = pipeline(obs, planning_horizon=32)
# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
print(
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())
obs = next_observation
except KeyboardInterrupt:
pass
print(f"Total reward: {total_reward}")

View File

@@ -8,7 +8,7 @@ config = dict(
n_samples=64,
horizon=32,
num_inference_steps=20,
n_guide_steps=2, # can set to 0 for faster sampling, does not use value network
n_guide_steps=2,
scale_grad_by_std=True,
scale=0.1,
eta=0.0,
@@ -40,7 +40,6 @@ if __name__ == "__main__":
# execute action in environment
next_observation, reward, terminal, _ = env.step(denorm_actions)
score = env.get_normalized_score(total_reward)
# update return
total_reward += reward
total_score += score
@@ -48,7 +47,6 @@ if __name__ == "__main__":
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
f" {total_score}"
)
# save observations for rendering
rollout.append(next_observation.copy())

View File

@@ -12,18 +12,9 @@ ___This script is experimental. The script fine-tunes the whole model and often
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd into the example folder and run
```bash
pip install -r requirements.txt
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
@@ -51,13 +42,11 @@ If you have already cloned the repo, then you won't need to go through these ste
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
@@ -65,6 +54,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
@@ -80,7 +70,7 @@ If you wish to use custom loading logic, you should modify the script, we have l
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
@@ -88,6 +78,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \

View File

@@ -1,7 +1,7 @@
diffusers==0.4.1
accelerate
torchvision
transformers>=4.21.0
datasets
ftfy
tensorboard
modelcards

View File

@@ -1,5 +1,5 @@
diffusers>=0.5.1
transformers>=4.21.0
datasets
flax
optax
torch

View File

@@ -15,18 +15,15 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
@@ -39,13 +36,6 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -196,12 +186,12 @@ def parse_args():
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
@@ -345,24 +335,10 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Freeze vae and text_encoder
vae.requires_grad_(False)
@@ -396,7 +372,7 @@ def main():
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -520,9 +496,9 @@ def main():
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encoder and vae to gpu.
@@ -586,17 +562,9 @@ def main():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -632,12 +600,14 @@ def main():
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
revision=args.revision,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)

View File

@@ -23,7 +23,6 @@ from diffusers import (
FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
@@ -33,9 +32,6 @@ from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = logging.getLogger(__name__)

View File

@@ -16,18 +16,8 @@ Colab for inference
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd into the example folder and run
```bash
pip install -r requirements.txt
pip install diffusers"[training]" accelerate "transformers>=4.21.0"
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
@@ -57,8 +47,6 @@ Now let's get our dataset. Download 3-4 images from [here](https://drive.google.c
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

View File

@@ -1,6 +1,3 @@
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards

View File

@@ -1,3 +1,4 @@
diffusers>=0.5.1
transformers>=4.21.0
flax
optax

View File

@@ -19,7 +19,6 @@ from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -49,18 +48,14 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
def parse_args():
@@ -71,12 +66,6 @@ def parse_args():
default=500,
help="Save learned_embeds.bin every X updates steps.",
)
parser.add_argument(
"--only_save_embeds",
action="store_true",
default=False,
help="Save only the embeddings for the new concept.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
@@ -84,13 +73,6 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -423,21 +405,9 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -471,7 +441,7 @@ def main():
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
@@ -562,17 +532,9 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
@@ -594,8 +556,7 @@ def main():
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
save_progress(text_encoder, placeholder_token_id, accelerator, args)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -608,25 +569,18 @@ def main():
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
if args.push_to_hub and args.only_save_embeds:
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
save_full_model = True
else:
save_full_model = not args.only_save_embeds
if save_full_model:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
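`save_progress` writes nothing more than a one-entry `{placeholder_token: embedding}` dict, so the result can be re-injected into any compatible text encoder. A sketch of loading it back (the file path and CLIP checkpoint are illustrative):
```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

learned = torch.load("learned_embeds.bin")   # {placeholder_token: embedding}
(token, embedding), = learned.items()

tokenizer.add_tokens(token)                  # register the new token
text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(token)
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[token_id] = embedding
```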

View File

@@ -24,7 +24,6 @@ from diffusers import (
FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
@@ -56,9 +55,6 @@ else:
}
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = logging.getLogger(__name__)

View File

@@ -6,21 +6,10 @@ Creating a training image set is [described in a different document](https://hug
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
pip install diffusers[training] accelerate datasets tensorboard
```
Then cd into the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
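# (hunk truncated here; the README presumably continues with the standard init step)
accelerate config
```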

View File

@@ -11,11 +11,12 @@ import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import deprecate
from huggingface_hub import HfFolder, Repository, whoami
from packaging import version
from torchvision.transforms import (
CenterCrop,
Compose,
@@ -28,11 +29,8 @@ from torchvision.transforms import (
from tqdm.auto import tqdm
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
diffusers_version = version.parse(version.parse(__version__).base_version)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -196,10 +194,9 @@ def parse_args():
)
parser.add_argument(
"--prediction_type",
type=str,
default="epsilon",
choices=["epsilon", "sample"],
"--predict_epsilon",
action="store_true",
default=True,
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
@@ -259,13 +256,13 @@ def main(args):
"UpBlock2D",
),
)
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
if accepts_predict_epsilon:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
prediction_type=args.prediction_type,
predict_epsilon=args.predict_epsilon,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
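Both sides of this hunk use the same defensive pattern: probe the scheduler's constructor signature before passing a newer keyword, so the script keeps working against older diffusers releases. A standalone sketch of the pattern:
```python
import inspect

from diffusers import DDPMScheduler

kwargs = {"num_train_timesteps": 1000, "beta_schedule": "linear"}
# Only pass `prediction_type` when the installed DDPMScheduler accepts it.
if "prediction_type" in inspect.signature(DDPMScheduler.__init__).parameters:
    kwargs["prediction_type"] = "epsilon"
noise_scheduler = DDPMScheduler(**kwargs)
```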
@@ -322,12 +319,7 @@ def main(args):
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
ema_model = EMAModel(
accelerator.unwrap_model(model),
inv_gamma=args.ema_inv_gamma,
power=args.ema_power,
max_value=args.ema_max_decay,
)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
# Handle the repository creation
if accelerator.is_main_process:
@@ -373,9 +365,9 @@ def main(args):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.prediction_type == "epsilon":
if args.predict_epsilon:
loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.prediction_type == "sample":
else:
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
@@ -384,8 +376,6 @@ def main(args):
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
accelerator.backward(loss)
@@ -419,7 +409,11 @@ def main(args):
scheduler=noise_scheduler,
)
generator = torch.Generator(device=pipeline.device).manual_seed(0)
deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
if diffusers_version < version.parse("0.8.0"):
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=pipeline.device).manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(
generator=generator,

View File

@@ -1,8 +1,6 @@
import argparse
import math
import os
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
@@ -11,10 +9,9 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami
from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import (
CenterCrop,
@@ -28,22 +25,9 @@ from torchvision.transforms import (
from tqdm.auto import tqdm
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
@@ -129,22 +113,8 @@ def main(args):
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
@@ -216,9 +186,10 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()
accelerator.end_training()

View File

@@ -33,7 +33,6 @@ from diffusers import (
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
@@ -41,9 +40,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
def shave_segments(path, n_shave_prefix_segments=1):
@@ -209,12 +207,11 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config, image_size: int):
def create_unet_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
@@ -232,19 +229,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim = [5, 10, 20, 20]
config = dict(
sample_size=image_size // vae_scale_factor,
sample_size=unet_params.image_size,
in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types),
@@ -252,14 +238,13 @@ def create_unet_diffusers_config(original_config, image_size: int):
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_res_blocks,
cross_attention_dim=unet_params.context_dim,
attention_head_dim=head_dim,
use_linear_projection=use_linear_projection,
attention_head_dim=unet_params.num_heads,
)
return config
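To make the `sample_size` change concrete, a worked example (assuming a Stable Diffusion v2 768 checkpoint whose VAE uses `ch_mult = [1, 2, 4, 4]`):
```python
# latent resolution = pixel resolution / VAE downscale factor
ch_mult = [1, 2, 4, 4]
vae_scale_factor = 2 ** (len(ch_mult) - 1)     # 8
image_size = 768
sample_size = image_size // vae_scale_factor   # 96, the UNet's latent sample size
```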
def create_vae_diffusers_config(original_config, image_size: int):
def create_vae_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
@@ -271,7 +256,7 @@ def create_vae_diffusers_config(original_config, image_size: int):
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=image_size,
sample_size=vae_params.resolution,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
@@ -648,89 +633,6 @@ def convert_ldm_clip_checkpoint(checkpoint):
return text_model
def convert_paint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# load clip vision
model.model.load_state_dict(text_model_dict)
# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}
MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}
mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]
num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
model.mapper.load_state_dict(mapped_weights)
# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)
# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)
# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model
def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
# SKIP for now - need openclip -> HF conversion script here
# keys = list(checkpoint.keys())
#
# text_model_dict = {}
# for key in keys:
# if key.startswith("cond_stage_model.model.transformer"):
# text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
#
# text_model.load_state_dict(text_model_dict)
return text_model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -744,42 +646,12 @@ if __name__ == "__main__":
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
)
parser.add_argument(
"--pipeline_type",
default=None,
type=str,
help="The pipeline type. If `None` pipeline will be automatically inferred.",
)
parser.add_argument(
"--image_size",
default=None,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
parser.add_argument(
"--prediction_type",
default=None,
type=str,
help=(
"The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
" Siffusion v2 Base. Use 'v-prediction' for Stable Diffusion v2."
),
)
parser.add_argument(
"--extract_ema",
action="store_true",
@@ -790,135 +662,73 @@ if __name__ == "__main__":
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
image_size = args.image_size
prediction_type = args.prediction_type
checkpoint = torch.load(args.checkpoint_path)
global_step = checkpoint["global_step"]
checkpoint = checkpoint["state_dict"]
if args.original_config_file is None:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
# model_type = "v2"
os.system(
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
)
args.original_config_file = "./v2-inference-v.yaml"
else:
# model_type = "v1"
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
original_config = OmegaConf.load(args.original_config_file)
if args.num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
# as it relies on a brittle global step parameter here
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
if image_size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
else:
if prediction_type is None:
prediction_type = "epsilon"
if image_size is None:
image_size = 512
checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint["state_dict"]
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
)
if args.scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
scheduler = EulerAncestralDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
scheduler = DPMSolverMultistepScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "ddim":
scheduler = scheduler
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet = UNet2DConditionModel(**unet_config)
unet_config = create_unet_diffusers_config(original_config)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
)
unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
vae_config = create_vae_diffusers_config(original_config)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model.
model_type = args.pipeline_type
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
elif model_type == "FrozenCLIPEmbedder":
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
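For reference, a typical invocation of this conversion script (the script name and file paths are illustrative):
```bash
python convert_original_stable_diffusion_to_diffusers.py \
  --checkpoint_path ./sd-v1-4.ckpt \
  --original_config_file ./v1-inference.yaml \
  --scheduler_type pndm \
  --dump_path ./sd-v1-4-diffusers
```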

View File

@@ -215,10 +215,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
feature_extractor = pipeline.feature_extractor
else:
safety_checker = None
feature_extractor = None
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -228,8 +226,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
feature_extractor=pipeline.feature_extractor,
)
onnx_pipeline.save_pretrained(output_path)
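Once exported, the ONNX pipeline loads back the usual way. A minimal sketch (the output path and execution provider are illustrative):
```python
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./onnx-output", provider="CPUExecutionProvider"
)
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```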

View File

@@ -1,791 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the Versatile Stable Diffusion checkpoints. """
import argparse
from argparse import Namespace
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
SCHEDULER_CONFIG = Namespace(
**{
"beta_linear_start": 0.00085,
"beta_linear_end": 0.012,
"timesteps": 1000,
"scale_factor": 0.18215,
}
)
IMAGE_UNET_CONFIG = Namespace(
**{
"input_channels": 4,
"model_channels": 320,
"output_channels": 4,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
TEXT_UNET_CONFIG = Namespace(
**{
"input_channels": 768,
"model_channels": 320,
"output_channels": 768,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"second_dim": [4, 4, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
AUTOENCODER_CONFIG = Namespace(
**{
"double_z": True,
"z_channels": 4,
"resolution": 256,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
}
)
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif path["old"] in old_checkpoint:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_image_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=unet_params.input_channels,
out_channels=unet_params.output_channels,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_text_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=(unet_params.input_channels, 1, 1),
out_channels=(unet_params.output_channels, 1, 1),
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_vae_diffusers_config(vae_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=vae_params.resolution,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
latent_channels=vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks,
)
return config
def create_diffusers_scheduler(original_config):
scheduler = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return scheduler
def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print("Checkpoint has both EMA and non-EMA weights.")
if extract_ema:
print(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
elif f"input_blocks.{i}.0.weight" in unet_state_dict:
# text_unet uses linear layers in place of downsamplers
shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.1.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.1.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.1.bias"
)
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.2.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.2.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.2.bias"
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
def convert_vd_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
for key in keys:
vae_state_dict[key] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
)
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
scheduler_config = SCHEDULER_CONFIG
num_train_timesteps = scheduler_config.timesteps
beta_start = scheduler_config.beta_linear_start
beta_end = scheduler_config.beta_linear_end
if args.scheduler_type == "pndm":
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
steps_offset=1,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "ddim":
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel models.
if args.unet_checkpoint_path is not None:
# image UNet
image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG)
checkpoint = torch.load(args.unet_checkpoint_path)
converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
)
image_unet = UNet2DConditionModel(**image_unet_config)
image_unet.load_state_dict(converted_image_unet_checkpoint)
# text UNet
text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG)
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
)
text_unet = UNetFlatConditionModel(**text_unet_config)
text_unet.load_state_dict(converted_text_unet_checkpoint)
# Convert the VAE model.
if args.vae_checkpoint_path is not None:
vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
checkpoint = torch.load(args.vae_checkpoint_path)
converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pipe = VersatileDiffusionPipeline(
scheduler=scheduler,
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
)
pipe.save_pretrained(args.dump_path)
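After conversion, the saved pipeline can be exercised end to end. A sketch, assuming the dump path above and that the installed release exposes the `text_to_image` helper on the mega pipeline:
```python
from diffusers import VersatileDiffusionPipeline

pipe = VersatileDiffusionPipeline.from_pretrained("./versatile-diffusion-dump")
image = pipe.text_to_image("an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```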

View File

@@ -91,15 +91,12 @@ _deps = [
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"k-diffusion",
"librosa",
"modelcards>=0.1.4",
"numpy",
"parameterized",
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece>=0.1.91,!=0.1.92",
"scipy",
"regex!=2019.12.17",
@@ -107,7 +104,7 @@ _deps = [
"tensorboard",
"torch>=1.4",
"torchvision",
"transformers>=4.25.1",
"transformers>=4.21.0",
]
# this is a lookup table with items like:
@@ -183,17 +180,14 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"k-diffusion",
"librosa",
"parameterized",
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece",
"scipy",
"torchvision",
"transformers",
"transformers"
)
extras["torch"] = deps_list("torch", "accelerate")
@@ -218,7 +212,7 @@ install_requires = [
setup(
name="diffusers",
version="0.10.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
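The `extras` dictionaries above map directly onto pip extras, so the dependency-group changes in this hunk surface to users as, for example:
```bash
pip install "diffusers[torch]"      # torch + accelerate
pip install "diffusers[training]"   # accelerate, datasets, tensorboard, modelcards
pip install "diffusers[test]"       # the test dependency set
```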

View File

@@ -1,41 +1,22 @@
__version__ = "0.10.0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_inflect_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
is_transformers_available,
is_transformers_version,
is_unidecode_available,
logging,
)
# Make sure `transformers` is up to date
if is_transformers_available():
import transformers
__version__ = "0.8.0.dev0"
if is_transformers_version("<", "4.25.1"):
raise ImportError(
f"`diffusers` requires transformers >= 4.25.1 to function correctly, but {transformers.__version__} was"
" found in your environment. You can upgrade it with pip: `pip install transformers --upgrade`"
)
else:
pass
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import logging
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
@@ -63,14 +44,10 @@ else:
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
@@ -78,57 +55,30 @@ else:
VQDiffusionScheduler,
)
from .training_utils import EMAModel
try:
if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .utils.dummy_pt_objects import * # noqa F403
if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
if is_torch_available() and is_transformers_available():
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipelines import StableDiffusionKDiffusionPipeline
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
@@ -136,21 +86,10 @@ else:
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from .pipelines import AudioDiffusionPipeline, Mel
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
@@ -165,11 +104,10 @@ else:
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .utils.dummy_flax_objects import * # noqa F403
if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
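The import guards above all follow one recurring pattern: probe for an optional backend, and fall back to auto-generated dummy objects that only raise a helpful error when actually used. A condensed, self-contained sketch of that pattern under assumed names (`is_torch_available` and `OptionalDependencyNotAvailable` mirror the utilities referenced above):

import importlib.util

class OptionalDependencyNotAvailable(BaseException):
    """Raised at import time when a soft dependency is missing."""

def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None

class DummyDDIMScheduler:
    # stands in for the real class so `from diffusers import DDIMScheduler` never fails outright
    def __init__(self, *args, **kwargs):
        raise ImportError("DDIMScheduler requires the PyTorch backend: `pip install torch`.")

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    DDIMScheduler = DummyDDIMScheduler
else:
    class DDIMScheduler:  # the real, torch-backed implementation would live here
        ...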

View File

@@ -24,8 +24,6 @@ import re
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union
import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -82,21 +80,20 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
overridden by parent class).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
class).
"""
config_name = None
ignore_for_config = []
has_compatibles = False
_deprecated_kwargs = []
def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = __version__
# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
@@ -201,11 +198,6 @@ class ConfigMixin:
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# add possible deprecated kwargs
for deprecated_kwarg in cls._deprecated_kwargs:
if deprecated_kwarg in unused_kwargs:
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
@@ -470,7 +462,7 @@ class ConfigMixin:
unused_kwargs = {**config_dict, **kwargs}
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
return init_dict, unused_kwargs, hidden_config_dict
@@ -501,15 +493,6 @@ class ConfigMixin:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -537,7 +520,7 @@ def register_to_config(init):
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
@@ -562,9 +545,7 @@ def register_to_config(init):
if k not in ignore and k not in new_kwargs
}
)
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
return inner_init
@@ -581,7 +562,7 @@ def flax_register_to_config(cls):
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items()}
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# Retrieve default values
fields = dataclasses.fields(self)

View File

@@ -15,15 +15,12 @@ deps = {
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion",
"librosa": "librosa",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"parameterized": "parameterized",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"safetensors": "safetensors",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
"regex": "regex!=2019.12.17",
@@ -31,5 +28,5 @@ deps = {
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"transformers": "transformers>=4.21.0",
}

View File

@@ -23,22 +23,6 @@ from ...utils.dummy_pt_objects import DDPMScheduler
class ValueGuidedRLPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
Parameters:
value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
application is [`DDPMScheduler`].
env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
"""
def __init__(
self,
value_function: UNet1DModel,
@@ -94,26 +78,20 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
# permute to match dimension for pre-trained models
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
grad = torch.autograd.grad([y.sum()], [x])[0]
posterior_variance = self.scheduler._get_variance(i)
model_std = torch.exp(0.5 * posterior_variance)
grad = model_std * grad
grad[timesteps < 2] = 0
x = x.detach()
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: verify deprecation of this kwarg
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory (set the initial state)
# apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y
@@ -147,6 +125,5 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
else:
# if we didn't run value guiding, select a random action
selected_index = np.random.randint(0, batch_size)
denorm_actions = denorm_actions[selected_index, 0]
return denorm_actions
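The guidance loop above nudges sampled trajectories along the gradient of a learned value function, scaled by the scheduler's posterior standard deviation. A self-contained sketch of that update with toy stand-ins for the value model and variance; names like `scale` and `n_guide_steps` mirror the pipeline, everything else is illustrative:

import torch

def toy_value_function(x):
    # stand-in for UNet1DModel: score each (horizon, dim) trajectory with a scalar value
    return x.pow(2).sum(dim=(1, 2))

x = torch.randn(4, 16, 3)                 # batch of candidate trajectories
posterior_variance = torch.tensor(0.04)   # would come from scheduler._get_variance(i)
scale = 0.1
n_guide_steps = 2

for _ in range(n_guide_steps):
    with torch.enable_grad():
        x.requires_grad_()
        y = toy_value_function(x)
        grad = torch.autograd.grad([y.sum()], [x])[0]
    grad = torch.exp(0.5 * posterior_variance) * grad  # scale by the model std, as in the pipeline
    x = x.detach() + scale * grad                      # move trajectories toward higher value
print(x.shape)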

View File

@@ -15,16 +15,16 @@
import os
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
import requests
from huggingface_hub import HfFolder, whoami
from huggingface_hub import HfFolder, Repository, whoami
from . import __version__
from .utils import ENV_VARS_TRUE_VALUES, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
_flax_version,
_jax_version,
@@ -46,9 +46,7 @@ logger = logging.get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
@@ -75,27 +73,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def send_telemetry(data: Dict, name: str):
"""
Sends logs to the Hub telemetry endpoint.
Args:
data: the fields to track, e.g. {"example_name": "dreambooth"}
name: a unique name to differentiate the telemetry logs, e.g. "diffusers_examples" or "diffusers_notebooks"
"""
if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
pass
headers = {"user-agent": http_user_agent(data)}
endpoint = HUGGINGFACE_CO_TELEMETRY + name
try:
r = requests.head(endpoint, headers=headers)
r.raise_for_status()
except Exception:
# We don't want to error in case of connection errors of any kind.
pass
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
@@ -106,6 +83,121 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"
def init_git_repo(args, at_init: bool = False):
"""
Initializes a git repo in `args.hub_model_id`.
Args:
at_init (`bool`, *optional*, defaults to `False`):
Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
"""
deprecation_message = (
"Please use `huggingface_hub.Repository`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate("init_git_repo()", "0.10.0", deprecation_message)
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return
hub_token = args.hub_token if hasattr(args, "hub_token") else None
use_auth_token = True if hub_token is None else hub_token
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
repo_name = Path(args.output_dir).absolute().name
else:
repo_name = args.hub_model_id
if "/" not in repo_name:
repo_name = get_full_repo_name(repo_name, token=hub_token)
try:
repo = Repository(
args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
private=args.hub_private_repo,
)
except EnvironmentError:
if args.overwrite_output_dir and at_init:
# Try again after wiping output_dir
shutil.rmtree(args.output_dir)
repo = Repository(
args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
else:
raise
repo.git_pull()
# By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"])
return repo
def push_to_hub(
args,
pipeline,
repo: Repository,
commit_message: Optional[str] = "End of training",
blocking: bool = True,
**kwargs,
) -> str:
"""
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters:
commit_message (`str`, *optional*, defaults to `"End of training"`):
Message to commit while pushing.
blocking (`bool`, *optional*, defaults to `True`):
Whether the function should return only when the `git push` has finished.
kwargs:
Additional keyword arguments passed along to [`create_model_card`].
Returns:
The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
commit and an object to track the progress of the commit if `blocking=True`
"""
deprecation_message = (
"Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. "
"See `examples/unconditional_image_generation/train_unconditional.py` for an example."
)
deprecate("push_to_hub()", "0.10.0", deprecation_message)
if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
model_name = Path(args.output_dir).name
else:
model_name = args.hub_model_id.split("/")[-1]
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving pipeline checkpoint to {output_dir}")
pipeline.save_pretrained(output_dir)
# Only push from one node.
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if (
blocking
and len(repo.command_queue) > 0
and repo.command_queue[-1] is not None
and not repo.command_queue[-1].is_done
):
repo.command_queue[-1]._process.kill()
git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
# push separately the model card to be independent from the rest of the model
create_model_card(args, model_name=model_name)
try:
repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
except EnvironmentError as exc:
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
return git_head_commit_url
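The deprecation messages above point users to `huggingface_hub.Repository` directly. A minimal sketch of that replacement flow, assuming a trained pipeline already saved to `output_dir` and a token available in the local Hugging Face cache; the repo name and directory are placeholders:

from huggingface_hub import Repository

output_dir = "./my-diffusion-model"   # where pipeline.save_pretrained(...) wrote the weights
repo = Repository(output_dir, clone_from="username/my-diffusion-model")
repo.git_pull()                       # stay in sync before committing
repo.push_to_hub(commit_message="End of training", blocking=True, auto_lfs_prune=True)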
def create_model_card(args, model_name):
if not is_modelcards_available():
raise ValueError(

View File

@@ -28,7 +28,6 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__, is_torch_available
from .hub_utils import send_telemetry
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .utils import (
CONFIG_NAME,
@@ -333,17 +332,13 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
)
else:
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_path_with_subfolder}."
)
send_telemetry(
{"model_class": cls.__name__, "model_path": "local", "framework": "flax"},
name="diffusers_from_pretrained",
)
else:
try:
model_file = hf_hub_download(
@@ -359,10 +354,6 @@ class FlaxModelMixin:
subfolder=subfolder,
revision=revision,
)
send_telemetry(
{"model_class": cls.__name__, "model_path": "hub", "framework": "flax"},
name="diffusers_from_pretrained",
)
except RepositoryNotFoundError:
raise EnvironmentError(

View File

@@ -26,15 +26,12 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__
from .hub_utils import send_telemetry
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
logging,
)
@@ -54,9 +51,6 @@ if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
if is_safetensors_available():
import safetensors
def get_parameter_device(parameter: torch.nn.Module):
try:
@@ -90,13 +84,10 @@ def get_parameter_dtype(parameter: torch.nn.Module):
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return torch.load(checkpoint_file, map_location="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
@@ -113,7 +104,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
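The reworked `load_state_dict` above dispatches on the file name: the historical `.bin` checkpoint goes through `torch.load`, anything else is treated as a safetensors file. A stripped-down sketch of that dispatch; the constant approximates the one defined in the library's utils:

import os
import torch
import safetensors.torch

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # assumed to match the library constant

def load_state_dict(checkpoint_file):
    # pickle-based checkpoints keep the historical name; everything else is loaded via safetensors
    if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
        return torch.load(checkpoint_file, map_location="cpu")
    return safetensors.torch.load_file(checkpoint_file, device="cpu")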
@@ -192,8 +183,7 @@ class ModelMixin(torch.nn.Module):
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = False,
save_function: Callable = torch.save,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -208,21 +198,12 @@ class ModelMixin(torch.nn.Module):
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
need to replace `torch.save` by another method.
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save
os.makedirs(save_directory, exist_ok=True)
model_to_save = self
@@ -235,21 +216,18 @@ class ModelMixin(torch.nn.Module):
# Save the model
state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -354,7 +332,7 @@ class ModelMixin(torch.nn.Module):
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
@@ -397,44 +375,80 @@ class ModelMixin(torch.nn.Module):
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
model_file = None
if is_safetensors_available():
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
else:
try:
model_file = cls._get_model_file(
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
filename=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
except:
pass
if model_file is None:
model_file = cls._get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}"
)
# restore default dtype
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
config, unused_kwargs = cls.load_config(
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -448,7 +462,6 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
@@ -469,7 +482,7 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [],
}
else:
config, unused_kwargs = cls.load_config(
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -483,24 +496,8 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file)
dtype = set(v.dtype for v in state_dict.values())
if len(dtype) > 1 and torch.float32 not in dtype:
raise ValueError(
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
f" make sure that {model_file} weights have only one dtype."
)
elif len(dtype) > 1 and torch.float32 in dtype:
dtype = torch.float32
else:
dtype = dtype.pop()
# move model to correct dtype
model = model.to(dtype)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
@@ -532,100 +529,6 @@ class ModelMixin(torch.nn.Module):
return model
@classmethod
def _get_model_file(
cls,
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
send_telemetry(
{"model_class": cls.__name__, "model_path": "local", "framework": "pytorch"},
name="diffusers_from_pretrained",
)
return model_file
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
send_telemetry(
{"model_class": cls.__name__, "model_path": "hub", "framework": "pytorch"},
name="diffusers_from_pretrained",
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
@classmethod
def _load_pretrained_model(
cls,
@@ -774,86 +677,15 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
Args:
model (`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model
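On the saving side, the `safe_serialization` flag shown in this file simply swaps the weights file name and the save function. A compact sketch of that choice, with the file-name constants assumed rather than imported:

import torch
import safetensors.torch

WEIGHTS_NAME = "diffusion_pytorch_model.bin"                  # assumed constants, mirroring utils
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"

def pick_save(safe_serialization: bool):
    # safetensors avoids pickle and is safe to load from untrusted sources
    if safe_serialization:
        return safetensors.torch.save_file, SAFETENSORS_WEIGHTS_NAME
    return torch.save, WEIGHTS_NAME

save_function, weights_name = pick_save(safe_serialization=True)
state_dict = {"weight": torch.zeros(2, 2)}
save_function(state_dict, weights_name)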

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional
@@ -99,12 +98,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
@@ -130,10 +125,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -159,8 +151,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
for d in range(num_layers)
]
@@ -168,14 +158,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 4. Define output layers
if self.is_input_continuous:
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
"""
Args:
@@ -199,16 +190,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
@@ -218,17 +203,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
@@ -237,13 +213,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
output = F.log_softmax(logits, dim=1, dtype=torch.double).float()
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
class AttentionBlock(nn.Module):
"""
@@ -284,45 +264,11 @@ class AttentionBlock(nn.Module):
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, 1)
self._use_memory_efficient_attention_xformers = False
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
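The `set_use_memory_efficient_attention_xformers` helper above validates the environment by running one tiny attention call before flipping the flag. A hedged sketch of that probe; the dummy shapes copy the snippet above, and it only succeeds where CUDA and xformers are actually installed:

import torch

def can_use_xformers() -> bool:
    try:
        import xformers.ops
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    try:
        # smoke test: a single memory-efficient attention call on dummy tensors
        q = k = v = torch.randn((1, 2, 40), device="cuda")
        xformers.ops.memory_efficient_attention(q, k, v)
    except Exception:
        return False
    return True

use_xformers = can_use_xformers()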
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states):
residual = hidden_states
@@ -340,40 +286,62 @@ class AttentionBlock(nn.Module):
scale = 1 / math.sqrt(self.channels / self.num_heads)
query_proj = self.reshape_heads_to_batch_dim(query_proj)
key_proj = self.reshape_heads_to_batch_dim(key_proj)
value_proj = self.reshape_heads_to_batch_dim(value_proj)
if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
# get scores
if self.num_heads > 1:
query_states = (
self.transpose_for_scores(query_proj)
.contiguous()
.view(batch * self.num_heads, height * width, self.num_head_size)
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)
key_states = (
self.transpose_for_scores(key_proj)
.transpose(3, 2)
.contiguous()
.view(batch * self.num_heads, self.num_head_size, height * width)
)
value_states = (
self.transpose_for_scores(value_proj)
.contiguous()
.view(batch * self.num_heads, height * width, self.num_head_size)
)
else:
query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
attention_scores = torch.baddbmm(
torch.empty(
query_states.shape[0],
query_states.shape[1],
key_states.shape[2],
dtype=query_states.dtype,
device=query_states.device,
),
query_states,
key_states,
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value_states)
if self.num_heads > 1:
hidden_states = (
hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
.permute(0, 2, 1, 3)
.contiguous()
)
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
hidden_states = hidden_states + residual
if self.rescale_output_factor != 1.0:
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
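The rewritten attention above folds the scale factor into a single `torch.baddbmm` call: with `beta=0` the uninitialized `empty` input is ignored, so only `alpha * (query @ key^T)` is computed, in one kernel. A standalone sketch of the scoring step with arbitrary shapes:

import math
import torch

batch_heads, q_len, kv_len, head_dim = 2, 8, 8, 40
query = torch.randn(batch_heads, q_len, head_dim)
key = torch.randn(batch_heads, kv_len, head_dim)
value = torch.randn(batch_heads, kv_len, head_dim)
scale = 1 / math.sqrt(head_dim)

# beta=0 makes the contents of `empty` irrelevant; only alpha * query @ key^T survives
attention_scores = torch.baddbmm(
    torch.empty(batch_heads, q_len, kv_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
attention_probs = attention_scores.softmax(dim=-1)
out = torch.bmm(attention_probs, value)
print(out.shape)  # torch.Size([2, 8, 40])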
@@ -404,60 +372,40 @@ class BasicTransformerBlock(nn.Module):
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none
# 2. Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if context is none
# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.attn2 = None
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# 3. Feed-forward
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
@@ -488,18 +436,13 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -530,19 +473,16 @@ class CrossAttention(nn.Module):
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
@@ -557,45 +497,48 @@ class CrossAttention(nn.Module):
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
context_sequence_length = context.shape[1]
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
query = (
self.reshape_heads_to_batch_dim(query)
.permute(0, 2, 1, 3)
.reshape(batch_size * self.heads, sequence_length, dim // self.heads)
)
value = (
self.reshape_heads_to_batch_dim(value)
.permute(0, 2, 1, 3)
.reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
@@ -608,23 +551,16 @@ class CrossAttention(nn.Module):
return hidden_states
def _attention(self, query, key, value):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
key,
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
@@ -640,25 +576,14 @@ class CrossAttention(nn.Module):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query[start_idx:end_idx],
key[start_idx:end_idx],
beta=0,
alpha=self.scale,
)
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
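The sliced path above bounds peak memory by materializing the score matrix for only `slice_size` batch-head rows at a time. A toy sketch of that idea; the slicing granularity and shapes are illustrative, not the library's defaults:

import math
import torch

def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch*heads, seq_len, head_dim); process batch-head groups in chunks
    scale = 1 / math.sqrt(query.shape[-1])
    out = torch.empty_like(query)
    for start in range(0, query.shape[0], slice_size):
        end = start + slice_size
        scores = torch.bmm(query[start:end], key[start:end].transpose(-1, -2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
    return out

q = torch.randn(8, 16, 32)
print(sliced_attention(q, q, q, slice_size=2).shape)  # torch.Size([8, 16, 32])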
@@ -700,16 +625,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
geglu = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
self.net.append(geglu)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
@@ -721,27 +644,6 @@ class FeedForward(nn.Module):
return hidden_states
class GELU(nn.Module):
r"""
GELU activation function
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
# feedforward
class GEGLU(nn.Module):
r"""
@@ -800,121 +702,3 @@ class AdaLayerNorm(nn.Module):
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
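`AdaLayerNorm` above modulates a parameter-free LayerNorm with a scale and shift predicted from the timestep embedding. A minimal, self-contained sketch of that pattern; sizes are arbitrary, and the real module also embeds the integer timestep first:

import torch
import torch.nn as nn

class TinyAdaLayerNorm(nn.Module):
    def __init__(self, dim: int, emb_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(emb_dim, dim * 2)                 # predicts scale and shift
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)   # affine comes from the embedding instead

    def forward(self, x, emb):
        scale, shift = self.linear(self.silu(emb)).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift

layer = TinyAdaLayerNorm(dim=64, emb_dim=128)
out = layer(torch.randn(2, 10, 64), torch.randn(2, 1, 128))
print(out.shape)  # torch.Size([2, 10, 64])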
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
0
]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
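A minimal usage sketch of the dual-transformer mixing above. The config values and tensor shapes are assumptions for illustration, not part of the diff; the import path matches where this tree defines the class:

```python
import torch
from diffusers.models.attention import DualTransformer2DModel

# inner_dim = 8 * 40 = 320 must match in_channels for continuous inputs (assumed config)
model = DualTransformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    cross_attention_dim=768,
)

hidden_states = torch.randn(2, 320, 16, 16)  # (batch, channel, height, width)
# both conditions are concatenated along the token axis: 77 + 257 tokens,
# matching the default condition_lengths
encoder_hidden_states = torch.randn(2, 77 + 257, 768)

# each transformer sees only its slice of the tokens; their residuals are
# blended with mix_ratio before the input is added back
out = model(hidden_states, encoder_hidden_states).sample
```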


@@ -104,8 +104,6 @@ class FlaxBasicTransformerBlock(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -113,11 +111,10 @@ class FlaxBasicTransformerBlock(nn.Module):
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
# self attention
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -129,10 +126,7 @@ class FlaxBasicTransformerBlock(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
@@ -165,8 +159,6 @@ class FlaxTransformer2DModel(nn.Module):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_linear_projection (`bool`, defaults to `False`): tbd
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -175,70 +167,49 @@ class FlaxTransformer2DModel(nn.Module):
d_head: int
depth: int = 1
dropout: float = 0.0
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
if self.use_linear_projection:
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]
if self.use_linear_projection:
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states


@@ -49,11 +49,14 @@ def get_timestep_embedding(
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
sin = torch.sin(emb)
cos = torch.cos(emb)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
emb = torch.cat([cos, sin], dim=-1)
else:
emb = torch.cat([sin, cos], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
@@ -126,7 +129,7 @@ class GaussianFourierProjection(nn.Module):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
x_proj = x[:, None] * self.weight[None, :] * (2 * np.pi)
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
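A self-contained sketch of the sinusoidal embedding these hunks touch, simplified to ignore the `freq_shift` and `scale` handling of the real `get_timestep_embedding`; it shows that `flip_sin_to_cos` only swaps the concatenation order of the two halves:

```python
import math
import torch

def timestep_embedding(timesteps, dim, flip_sin_to_cos=False, max_period=10000):
    # one frequency per pair of channels, geometrically spaced
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
    args = timesteps[:, None].float() * freqs[None, :]
    sin, cos = torch.sin(args), torch.cos(args)
    # flipping changes only the order of the halves, not the values
    return torch.cat([cos, sin], dim=-1) if flip_sin_to_cos else torch.cat([sin, cos], dim=-1)

emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=8, flip_sin_to_cos=True)
```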


@@ -84,11 +84,10 @@ class FlaxTimesteps(nn.Module):
Time step embedding dimension
"""
dim: int = 32
flip_sin_to_cos: bool = False
freq_shift: float = 1
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
)


@@ -476,7 +476,9 @@ class ResnetBlock2D(nn.Module):
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
output_tensor = input_tensor + hidden_states
if self.output_scale_factor != 1.0:
output_tensor = output_tensor / self.output_scale_factor
return output_tensor
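The hunk above replaces an unconditional division with a guarded one, so the common `output_scale_factor == 1.0` case skips a needless elementwise kernel. A sketch of the pattern, with illustrative shapes:

```python
import torch

def residual_out(input_tensor, hidden_states, output_scale_factor=1.0):
    output_tensor = input_tensor + hidden_states
    # skip the no-op division (and its kernel launch) when the factor is 1.0
    if output_scale_factor != 1.0:
        output_tensor = output_tensor / output_scale_factor
    return output_tensor

out = residual_out(torch.randn(2, 4), torch.randn(2, 4))
```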


@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Input sample size.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
sample_size: Optional[int] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
@@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self,
@@ -209,11 +209,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
@@ -247,7 +242,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)


@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
from .attention import AttentionBlock, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -32,10 +32,6 @@ def get_down_block(
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -78,10 +74,6 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -145,10 +137,6 @@ def get_up_block(
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -178,10 +166,6 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -258,6 +242,7 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
**kwargs,
):
super().__init__()
@@ -337,13 +322,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
**kwargs,
):
super().__init__()
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -366,30 +348,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions = []
for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
@@ -408,6 +376,25 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -518,16 +505,11 @@ class CrossAttnDownBlock2D(nn.Module):
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
@@ -547,31 +529,16 @@ class CrossAttnDownBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -588,6 +555,25 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -1068,10 +1054,7 @@ class AttnUpBlock2D(nn.Module):
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
for resnet, attn, res_hidden_states in zip(self.resnets, self.attentions, reversed(res_hidden_states_tuple)):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
@@ -1103,16 +1086,11 @@ class CrossAttnUpBlock2D(nn.Module):
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
@@ -1134,31 +1112,16 @@ class CrossAttnUpBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -1169,6 +1132,27 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
self.gradient_checkpointing = False
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(
self,
hidden_states,


@@ -46,8 +46,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -70,8 +68,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -182,8 +178,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -207,8 +201,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -318,7 +310,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -340,7 +331,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
attentions.append(attn_block)


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -56,12 +56,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -98,7 +97,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -107,11 +105,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
attention_head_dim: int = 8,
):
super().__init__()
@@ -127,20 +121,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -159,12 +143,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
)
self.down_blocks.append(down_block)
@@ -177,11 +157,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
# count how many layers upsample the images
@@ -189,8 +166,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -218,11 +193,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -230,72 +201,40 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
num_slicable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
if slice_size is not None and slice_size > self.config.attention_head_dim:
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
self.mid_block.set_attention_slice(slice_size)
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
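A hedged usage sketch of the list-based `set_attention_slice` variant shown above, assuming a default-configured `UNet2DConditionModel` (illustrative only, and a full SD-sized model):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()  # default config; illustrative only

unet.set_attention_slice("auto")  # half of every sliceable head dim
unet.set_attention_slice("max")   # slice size 1 everywhere, most memory saved

# an int is broadcast to every sliceable layer; a list must have one entry
# per sliceable layer discovered by the recursive walk above
unet.set_attention_slice(2)
```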
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
@@ -306,14 +245,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
encoder_hidden_states (`torch.FloatTensor`):
(batch_size, sequence_length, hidden_size) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -344,14 +283,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -365,19 +298,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -403,7 +330,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,

View File

@@ -79,16 +79,12 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
attention_head_dim (`int`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
"""
sample_size: int = 32
@@ -101,15 +97,12 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
attention_head_dim: int = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
flip_sin_to_cos: bool = True
freq_shift: int = 0
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
@@ -138,19 +131,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
)
# time
self.time_proj = FlaxTimesteps(
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
)
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
only_cross_attention = self.only_cross_attention
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -165,10 +148,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
attn_num_head_channels=self.attention_head_dim,
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
@@ -188,16 +169,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
attn_num_head_channels=self.attention_head_dim,
dtype=self.dtype,
)
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
@@ -212,11 +190,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=reversed_attention_head_dim[i],
attn_num_head_channels=self.attention_head_dim,
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:


@@ -290,10 +290,15 @@ class VectorQuantizer(nn.Module):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
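The replaced lines compute nearest-codebook indices via the algebraic expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e instead of `torch.cdist`. A quick sketch checking the two routes agree (shapes are illustrative):

```python
import torch

torch.manual_seed(0)
z_flat = torch.randn(10, 4)    # flattened latents (N, D)
codebook = torch.randn(16, 4)  # embedding weight (K, D)

# expansion: ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z @ e^T
d = (
    z_flat.pow(2).sum(dim=1, keepdim=True)
    + codebook.pow(2).sum(dim=1)
    - 2 * z_flat @ codebook.t()
)

# argmin is unchanged by the monotone sqrt that cdist applies
assert torch.equal(d.argmin(dim=1), torch.cdist(z_flat, codebook).argmin(dim=1))
```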
@@ -560,7 +565,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
@@ -572,7 +576,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -581,34 +585,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
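A usage sketch of the sliced decoding shown above, assuming the variant that still exposes `enable_slicing`; the checkpoint name is illustrative:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"  # illustrative checkpoint
)
vae.enable_slicing()

# a batch of 4 latents is decoded one sample at a time, capping peak memory
latents = torch.randn(4, 4, 64, 64)
images = vae.decode(latents).sample

vae.disable_slicing()  # back to single-pass decoding
```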


@@ -29,7 +29,7 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .hub_utils import http_user_agent, send_telemetry
from .hub_utils import http_user_agent
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" if not from_pt else []
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
@@ -346,16 +346,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
send_telemetry(
{"pipeline_class": requested_pipeline_class, "pipeline_path": "hub", "framework": "flax"},
name="diffusers_from_pretrained",
)
else:
cached_folder = pretrained_model_name_or_path
send_telemetry(
{"pipeline_class": cls.__name__, "pipeline_path": "local", "framework": "flax"},
name="diffusers_from_pretrained",
)
config_dict = cls.load_config(cached_folder)
@@ -419,13 +411,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warning(
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)


@@ -26,14 +26,14 @@ import torch
import diffusers
import PIL
from huggingface_hub import model_info, snapshot_download
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent, send_telemetry
from .hub_utils import http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
@@ -44,7 +44,6 @@ from .utils import (
BaseOutput,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
is_transformers_available,
logging,
@@ -118,23 +117,6 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible
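The helper above maps each `.bin` filename to the `.safetensors` sibling it expects. A standalone sketch of just that mapping rule (posix paths assumed):

```python
import os

def expected_safetensors_name(pt_filename: str) -> str:
    prefix, raw = os.path.split(pt_filename)
    if raw == "pytorch_model.bin":  # transformers-style weight files
        return os.path.join(prefix, "model.safetensors")
    return pt_filename[: -len(".bin")] + ".safetensors"

assert expected_safetensors_name("unet/diffusion_pytorch_model.bin") == (
    "unet/diffusion_pytorch_model.safetensors"
)
assert expected_safetensors_name("text_encoder/pytorch_model.bin") == (
    "text_encoder/model.safetensors"
)
```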
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all pipelines.
@@ -147,13 +129,10 @@ class DiffusionPipeline(ConfigMixin):
Class attributes:
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
passed for the pipeline to function (should be overridden by subclasses).
"""
config_name = "model_index.json"
_optional_components = []
def register_modules(self, **kwargs):
# import it here to avoid circular import
@@ -188,11 +167,7 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
):
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
@@ -201,8 +176,6 @@ class DiffusionPipeline(ConfigMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
self.save_config(save_directory)
@@ -211,19 +184,12 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value):
if name not in expected_modules:
return False
if name in self._optional_components and value[0] is None:
return False
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -240,16 +206,7 @@ class DiffusionPipeline(ConfigMixin):
break
save_method = getattr(sub_model, save_method_name)
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
save_method(os.path.join(save_directory, pipeline_component_name))
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
@@ -392,8 +349,7 @@ class DiffusionPipeline(ConfigMixin):
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
return_cached_folder (`bool`, *optional*, defaults to `False`):
If set to `True`, the path to the downloaded cached folder will be returned in addition to the loaded pipeline.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
@@ -446,7 +402,33 @@ class DiffusionPipeline(ConfigMixin):
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
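A hedged sketch of how the validation above plays out at call time; the model id is illustrative:

```python
from diffusers import DiffusionPipeline

# fine: low_cpu_mem_usage defaults to True when `accelerate` is installed
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# rejected by the checks above: device_map dispatch requires low_cpu_mem_usage=True
try:
    DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", device_map="auto", low_cpu_mem_usage=False
    )
except ValueError as err:
    print(err)
```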
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -467,7 +449,7 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
ignore_patterns = "*.msgpack"
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -477,20 +459,10 @@ class DiffusionPipeline(ConfigMixin):
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
if is_safetensors_available():
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -504,16 +476,8 @@ class DiffusionPipeline(ConfigMixin):
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
send_telemetry(
{"pipeline_class": requested_pipeline_class, "pipeline_path": "hub", "framework": "pytorch"},
name="diffusers_from_pretrained",
)
else:
cached_folder = pretrained_model_name_or_path
send_telemetry(
{"pipeline_class": cls.__name__, "pipeline_path": "local", "framework": "pytorch"},
name="diffusers_from_pretrained",
)
config_dict = cls.load_config(cached_folder)
@@ -528,7 +492,9 @@ class DiffusionPipeline(ConfigMixin):
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
pipeline_class = get_class_from_dynamic_module(custom_pipeline, module_file=file_name, cache_dir=cache_dir)
pipeline_class = get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
)
elif cls != DiffusionPipeline:
pipeline_class = cls
else:
@@ -557,74 +523,38 @@ class DiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
init_kwargs = {}
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
if not is_pipeline_module and passed_class_obj[name] is not None:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
@@ -640,8 +570,14 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
@@ -661,7 +597,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
@@ -715,32 +651,19 @@ class DiffusionPipeline(ConfigMixin):
# 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
init_kwargs[module] = passed_class_obj[module]
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
if return_cached_folder:
return model, cached_folder
return model
@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -765,10 +688,8 @@ class DiffusionPipeline(ConfigMixin):
Returns:
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
if set(components.keys()) != expected_modules:
raise ValueError(
@@ -794,7 +715,7 @@ class DiffusionPipeline(ConfigMixin):
return pil_images
def progress_bar(self, iterable=None, total=None):
def progress_bar(self, iterable):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
@@ -802,78 +723,7 @@ class DiffusionPipeline(ConfigMixin):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
return tqdm(iterable, **self._progress_bar_config)
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.set_use_memory_efficient_attention_xformers(False)
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
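Note: the recursive dispatch above simply walks `module.children()` and forwards the flag to any submodule that exposes the setter. A minimal, self-contained sketch with a hypothetical attention block:

```python
import torch

# Minimal sketch of the recursive dispatch: any submodule exposing the toggle
# gets the message. The AttentionBlock class is a hypothetical stand-in.
class AttentionBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.use_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self.use_xformers = valid


def fn_recursive_set_mem_eff(module: torch.nn.Module, valid: bool) -> None:
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)
    for child in module.children():
        fn_recursive_set_mem_eff(child, valid)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), AttentionBlock())
fn_recursive_set_mem_eff(model, True)
assert model[1].use_xformers is True
```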
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def set_attention_slice(self, slice_size: Optional[int]):
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size)
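Note: a hedged usage sketch of the attention-slicing toggle defined above (the model id is an assumption):

```python
# Hedged usage sketch; the model id is an assumption for illustration.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_attention_slicing("auto")   # compute attention in slices to save memory
# ... run inference ...
pipe.disable_attention_slicing()        # back to single-step attention
```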

View File

@@ -126,7 +126,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```

View File

@@ -1,20 +1,7 @@
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
)
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
if is_torch_available():
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
@@ -24,49 +11,22 @@ else:
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from .audio_diffusion import AudioDiffusionPipeline, Mel
from ..utils.dummy_pt_objects import * # noqa F403
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
if is_torch_available() and is_transformers_available():
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
@@ -75,19 +35,5 @@ else:
StableDiffusionOnnxPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
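Note: one side of this diff guards each import behind an `OptionalDependencyNotAvailable` check rather than a plain `is_*_available()` test. A minimal sketch of that guard pattern, written against a hypothetical `somelib` dependency:

```python
# Sketch of the optional-dependency guard pattern; `somelib` is a hypothetical package.
import importlib.util


class OptionalDependencyNotAvailable(BaseException):
    """Raised when an optional backend is missing."""


def is_somelib_available() -> bool:
    return importlib.util.find_spec("somelib") is not None


try:
    if not is_somelib_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # fall back to a placeholder so the rest of the package still imports
    SomePipeline = None
else:
    from somelib import SomePipeline  # noqa: F401
```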

View File

@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import torch
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
@@ -68,7 +67,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -86,7 +84,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -117,8 +114,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -127,33 +124,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
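Note: the deprecation gate above compares the `_diffusers_version` recorded in the unet config against `0.9.0.dev0` using `packaging.version`. A worked example with an assumed config value:

```python
# Worked example of the version gate; the config value is assumed for illustration.
from packaging import version

config_version = "0.8.0"  # hypothetical unet.config._diffusers_version
is_unet_version_less_0_9_0 = version.parse(
    version.parse(config_version).base_version
) < version.parse("0.9.0.dev0")
assert is_unet_version_less_0_9_0
```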
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -163,24 +133,51 @@ class AltDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
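Note: the scale factor above is derived from the number of VAE downsampling levels; a worked example with typical (assumed) Stable-Diffusion-style values:

```python
# Worked example with typical (assumed) Stable-Diffusion-style VAE settings.
block_out_channels = (128, 256, 512, 512)              # 4 levels -> 3 downsamples
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8                           # latents are 8x smaller than pixels
```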
def enable_vae_slicing(self):
def enable_xformers_memory_efficient_attention(self):
r"""
Enable sliced VAE decoding.
Enable memory efficient attention as implemented in xformers.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.vae.enable_slicing()
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_vae_slicing(self):
def disable_xformers_memory_efficient_attention(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
Disable memory efficient attention as implemented in xformers.
"""
self.vae.disable_slicing()
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
@@ -195,15 +192,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
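Note: a hedged sketch of the sequential offload loop above using `accelerate.cpu_offload` directly (the model id, accelerate being installed, and an available CUDA device are assumptions):

```python
# Hedged sketch; model id, accelerate availability, and a CUDA device are assumptions.
import torch
from accelerate import cpu_offload
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
device = torch.device("cuda:0")
for model in [pipe.unet, pipe.text_encoder, pipe.vae]:
    if model is not None:
        # weights stay on CPU and are moved to the GPU only for each forward pass
        cpu_offload(model, device)
```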
@property
def _execution_device(self):
r"""
@@ -378,7 +370,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -398,8 +390,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -411,6 +403,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -418,9 +411,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -466,9 +459,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
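Note: with a typical `sample_size` of 64 and a `vae_scale_factor` of 8 (assumed values), the `or`-defaults above resolve to 512 x 512 pixels:

```python
# Worked example of the defaults above (sample_size and vae_scale_factor assumed).
sample_size, vae_scale_factor = 64, 8
height = None
height = height or sample_size * vae_scale_factor
assert height == 512
```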
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
@@ -507,29 +497,25 @@ class AltDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
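Note: the guidance step inside the loop is the standard classifier-free-guidance update; a minimal sketch on dummy tensors (shapes are illustrative assumptions):

```python
import torch

guidance_scale = 7.5
latents = torch.randn(1, 4, 64, 64)                    # illustrative latent shape
latent_model_input = torch.cat([latents] * 2)          # unconditional + text-conditioned copy
noise_pred = torch.randn(2, 4, 64, 64)                 # stand-in for the UNet output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == latents.shape
```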

Some files were not shown because too many files have changed in this diff.