Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick von Platen
b444122ab0 [Tests] Refactor integration tests 2022-10-31 09:51:04 +00:00
124 changed files with 753 additions and 9228 deletions

View File

@@ -1,50 +0,0 @@
name: Build Docker images (nightly)
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *" # every day at midnight
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
REGISTRY: diffusers
jobs:
build-docker-images:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
fail-fast: false
matrix:
image-name:
- diffusers-pytorch-cpu
- diffusers-pytorch-cuda
- diffusers-flax-cpu
- diffusers-flax-tpu
- diffusers-onnxruntime-cpu
- diffusers-onnxruntime-cuda
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ env.REGISTRY }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v3
with:
no-cache: true
context: ./docker/${{ matrix.image-name }}
push: true
tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest

View File

@@ -10,46 +10,19 @@ concurrency:
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 60
MPS_TORCH_VERSION: 1.13.0
jobs:
run_fast_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch CPU tests on Ubuntu
framework: pytorch
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- name: Fast Flax CPU tests on Ubuntu
framework: flax
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
run_tests_cpu:
name: CPU tests on Ubuntu
runs-on: [ self-hosted, docker-gpu ]
container:
image: ${{ matrix.config.image }}
image: python:3.7
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -58,6 +31,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
@@ -65,43 +40,23 @@ jobs:
run: |
python utils/print_env.py
- name: Run fast PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch' }}
- name: Run all fast tests on CPU
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast ONNXRuntime CPU tests
if: ${{ matrix.config.framework == 'onnxruntime' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
run: cat reports/tests_torch_cpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
name: pr_torch_cpu_test_reports
path: reports
run_fast_tests_apple_m1:
name: Fast PyTorch MPS tests on MacOS
run_tests_apple_m1:
name: MPS tests on Apple M1
runs-on: [ self-hosted, apple-m1 ]
steps:
@@ -133,7 +88,7 @@ jobs:
run: |
${CONDA_RUN} python utils/print_env.py
- name: Run fast PyTorch tests on M1 (MPS)
- name: Run all fast tests on MPS
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/

View File

@@ -6,7 +6,6 @@ on:
- main
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
@@ -14,38 +13,12 @@ env:
RUN_SLOW: yes
jobs:
run_slow_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Slow PyTorch CUDA tests on Ubuntu
framework: pytorch
runner: docker-gpu
image: diffusers/diffusers-pytorch-cuda
report: torch_cuda
- name: Slow Flax TPU tests on Ubuntu
framework: flax
runner: docker-tpu
image: diffusers/diffusers-flax-tpu
report: flax_tpu
- name: Slow ONNXRuntime CUDA tests on Ubuntu
framework: onnxruntime
runner: docker-gpu
image: diffusers/diffusers-onnxruntime-cuda
report: onnx_cuda
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
run_tests_single_gpu:
name: Diffusers tests
runs-on: [ self-hosted, docker-gpu, single-gpu ]
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
defaults:
run:
shell: bash
image: nvcr.io/nvidia/pytorch:22.07-py3
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
steps:
- name: Checkout diffusers
@@ -54,12 +27,14 @@ jobs:
fetch-depth: 2
- name: NVIDIA-SMI
if : ${{ matrix.config.runner == 'docker-gpu' }}
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
@@ -67,55 +42,29 @@ jobs:
run: |
python utils/print_env.py
- name: Run slow PyTorch CUDA tests
if: ${{ matrix.config.framework == 'pytorch' }}
- name: Run all (incl. slow) tests on GPU
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run slow Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run slow ONNXRuntime CUDA tests
if: ${{ matrix.config.framework == 'onnxruntime' }}
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
run: cat reports/tests_torch_gpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: ${{ matrix.config.report }}_test_reports
name: torch_test_reports
path: reports
run_examples_tests:
name: Examples PyTorch CUDA tests on Ubuntu
runs-on: docker-gpu
run_examples_single_gpu:
name: Examples tests
runs-on: [ self-hosted, docker-gpu, single-gpu ]
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
image: nvcr.io/nvidia/pytorch:22.07-py3
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
steps:
- name: Checkout diffusers
@@ -129,6 +78,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
@@ -140,11 +92,11 @@ jobs:
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/examples_torch_cuda_failures_short.txt
run: cat reports/examples_torch_gpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}

View File

@@ -27,12 +27,10 @@ More precisely, 🤗 Diffusers offers:
## Installation
### For PyTorch
**With `pip`**
```bash
pip install --upgrade diffusers[torch]
pip install --upgrade diffusers
```
**With `conda`**
@@ -41,14 +39,6 @@ pip install --upgrade diffusers[torch]
conda install -c conda-forge diffusers
```
### For Flax
**With `pip`**
```bash
pip install --upgrade diffusers[flax]
```
**Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
@@ -152,7 +142,11 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
lms = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -364,7 +358,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers["torch"] transformers
# !pip install diffusers transformers
from diffusers import DiffusionPipeline
device = "cuda"
@@ -383,7 +377,7 @@ image.save("squirrel.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers["torch"]
# !pip install diffusers
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256"

View File

@@ -1,42 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --upgrade --no-cache-dir \
clu \
"jax[cpu]>=0.2.16,!=0.3.2" \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -1,44 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
"jax[tpu]>=0.2.16,!=0.3.2" \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
python3 -m pip install --upgrade --no-cache-dir \
clu \
"flax>=0.4.1" \
"jaxlib>=0.1.65" && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -1,42 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
onnxruntime \
--extra-index-url https://download.pytorch.org/whl/cpu && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -1,42 +0,0 @@
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
"onnxruntime-gpu>=1.13.1" \
--extra-index-url https://download.pytorch.org/whl/cu117 && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -1,41 +0,0 @@
FROM ubuntu:20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
--extra-index-url https://download.pytorch.org/whl/cpu && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -1,41 +0,0 @@
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && \
apt install -y bash \
build-essential \
git \
git-lfs \
curl \
ca-certificates \
python3.8 \
python3-pip \
python3.8-venv && \
rm -rf /var/lib/apt/lists
# make sure to use venv
RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
--extra-index-url https://download.pytorch.org/whl/cu117 && \
python3 -m pip install --no-cache-dir \
accelerate \
datasets \
hf-doc-builder \
huggingface-hub \
modelcards \
numpy \
scipy \
tensorboard \
transformers
CMD ["/bin/bash"]

View File

@@ -78,8 +78,6 @@
- sections:
- local: api/pipelines/overview
title: "Overview"
- local: api/pipelines/cycle_diffusion
title: "Cycle Diffusion"
- local: api/pipelines/ddim
title: "DDIM"
- local: api/pipelines/ddpm
@@ -98,9 +96,5 @@
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/vq_diffusion
title: "VQ Diffusion"
- local: api/pipelines/repaint
title: "RePaint"
title: "Pipelines"
title: "API"

View File

@@ -49,12 +49,6 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
[[autodoc]] AutoencoderKL
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.attention.Transformer2DModelOutput
## FlaxModelMixin
[[autodoc]] FlaxModelMixin

View File

@@ -1,99 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Cycle Diffusion
## Overview
Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu, Fernando De la Torre.
The abstract of the paper is the following:
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.*
*Tips*:
- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoints
- Currently Cycle Diffusion only works with the [`DDIMScheduler`].
*Example*:
In the following we should how to best use the [`CycleDiffusionPipeline`]
```python
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("horse.png")
# let's specify a prompt
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
# call the pipeline
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
guidance_scale=2,
source_guidance_scale=1,
).images[0]
image.save("horse_to_elephant.png")
# let's try another example
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("black.png")
source_prompt = "A black colored car"
prompt = "A blue colored car"
# call the pipeline
torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
).images[0]
image.save("black_to_blue.png")
```
## CycleDiffusionPipeline
[[autodoc]] CycleDiffusionPipeline
- __call__

View File

@@ -28,7 +28,7 @@ or created independently from each other.
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
More specifically, we strive to provide pipelines that
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
@@ -41,24 +41,19 @@ If you are looking for *official* training examples, please have a look at [exam
The following table summarizes all officially supported pipelines, their corresponding paper, and if
available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -1,77 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# RePaint
## Overview
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) (PNDM) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
The abstract of the paper is the following:
Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.
The original codebase can be found [here](https://github.com/andreas128/RePaint).
## Available Pipelines:
| Pipeline | Tasks | Colab
|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:|
| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - |
## Usage example
```python
from io import BytesIO
import torch
import PIL
import requests
from diffusers import RePaintPipeline, RePaintScheduler
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
# Load the original image and the mask as PIL images
original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
original_image=original_image,
mask_image=mask_image,
num_inference_steps=250,
eta=0.0,
jump_length=10,
jump_n_sample=10,
generator=generator,
)
inpainted_image = output.images[0]
```
## RePaintPipeline
[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline
- __call__

View File

@@ -31,21 +31,6 @@ For more details about how Stable Diffusion works and how it differs from the ba
## Tips
### How to load and use different schedulers.
The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```
### How to conver all use cases with multiple or single pipeline
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
- Make use of the `components` functionality to instantiate all components in the most memory-efficient way:

View File

@@ -1,34 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VQDiffusion
## Overview
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo
The abstract of the paper is the following:
We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.
The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion).
## Available Pipelines:
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - |
## VQDiffusionPipeline
[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline
- __call__

View File

@@ -70,12 +70,6 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
[[autodoc]] DDPMScheduler
#### Multistep DPM-Solver
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
[[autodoc]] DPMSolverMultistepScheduler
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
@@ -118,34 +112,3 @@ Score SDE-VP is under construction.
</Tip>
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
#### Euler scheduler
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
Fast scheduler which often times generates good outputs with 20-30 steps.
[[autodoc]] EulerDiscreteScheduler
#### Euler Ancestral scheduler
Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson.
Fast scheduler which often times generates good outputs with 20-30 steps.
[[autodoc]] EulerAncestralDiscreteScheduler
#### VQDiffusionScheduler
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
[[autodoc]] VQDiffusionScheduler
#### RePaint scheduler
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
Intended for use with [`RePaintPipeline`].
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
[[autodoc]] RePaintScheduler

View File

@@ -34,8 +34,6 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
@@ -47,6 +45,5 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -12,12 +12,9 @@ specific language governing permissions and limitations under the License.
# Installation
Install 🤗 Diffusers for whichever deep learning library youre working with.
Install Diffusers for with PyTorch. Support for other libraries will come in the future
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
## Install with pip
@@ -39,30 +36,12 @@ source .env/bin/activate
Now you're ready to install 🤗 Diffusers with the following command:
**For PyTorch**
```bash
pip install diffusers["torch"]
```
**For Flax**
```bash
pip install diffusers["flax"]
pip install diffusers
```
## Install from source
Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
To install `accelerate`
```bash
pip install accelerate
```
Install 🤗 Diffusers from source with the following command:
```bash
@@ -88,18 +67,7 @@ Clone the repository and install 🤗 Diffusers with the following commands:
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
```
**For PyTorch**
```
pip install -e ".[torch]"
```
**For Flax**
```
pip install -e ".[flax]"
pip install -e .
```
These commands will link the folder you cloned the repository to and your Python library paths.

View File

@@ -22,7 +22,6 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
| fp16 | 3.61s | x2.63 |
| channels last | 3.30s | x2.88 |
| traced UNet | 3.21s | x2.96 |
| memory efficient attention | 2.63s | x3.61 |
<em>
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
@@ -291,41 +290,3 @@ pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```
## Memory Efficient Attention
Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) .
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):
| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------ |--------------------- |--------------------------------- |
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
| NVIDIA A10G | 8.88it/s | 15.6it/s |
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
| A100-SXM4-40GB | 18.6it/s | 29.it/s |
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
To leverage it just make sure you have:
- PyTorch > 1.12
- Cuda available
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
```python
from diffusers import StableDiffusionPipeline
import torch
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
with torch.inference_mode():
sample = pipe("a small cat")
# optional: You can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```

View File

@@ -19,8 +19,11 @@ specific language governing permissions and limitations under the License.
- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.6 or later (13.0 or later recommended).
- arm64 version of Python.
- PyTorch 1.13. You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/.
- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
```
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
```
## Inference Pipeline
@@ -60,4 +63,4 @@ pipeline.enable_attention_slicing()
## Known Issues
- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). This is being resolved, but for now we recommend to iterate instead of batching.
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend to iterate instead of batching.

View File

@@ -121,7 +121,7 @@ you could use it as follows:
```python
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN

View File

@@ -38,9 +38,9 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|---|---|:---:|:---:|
| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ |
| [**Textual Inversion**](./text_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)

View File

@@ -18,7 +18,6 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -374,49 +373,6 @@ for i in range(4):
for i, img in enumerate(images):
img.save(f"./composable_diffusion/image_{i}.png")
```
### Imagic Stable Diffusion
Allows you to edit an image using stable diffusion.
```python
import requests
from PIL import Image
from io import BytesIO
import torch
from diffusers import DiffusionPipeline, DDIMScheduler
has_cuda = torch.cuda.is_available()
device = torch.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
safety_checker=None,
use_auth_token=True,
custom_pipeline="imagic_stable_diffusion",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
).to(device)
generator = th.Generator("cuda").manual_seed(0)
seed = 0
prompt = "A photo of Barack Obama smiling with a big grin"
url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
res = pipe.train(
prompt,
init_image,
guidance_scale=7.5,
num_inference_steps=50,
generator=generator)
res = pipe(alpha=1)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1.png')
res = pipe(alpha=1.5)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1_5.png')
res = pipe(alpha=2)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_2.png')
```
### Seed Resizing
Test seed resizing. Originally generate an image in 512 by 512, then generate image with same seed at 512 by 592 using seed resizing. Finally, generate 512 by 592 using original stable diffusion pipeline.
@@ -500,4 +456,4 @@ res = pipe_compare(
image = res.images[0]
image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
```
```

View File

@@ -1,476 +0,0 @@
"""
modeled after the textual_inversion.py / train_dreambooth.py and the work
of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
"""
import inspect
import warnings
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import PIL
from accelerate import Accelerator
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
class ImagicStableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for imagic image editing.
See paper here: https://arxiv.org/pdf/2210.09276.pdf
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def train(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
height: Optional[int] = 512,
width: Optional[int] = 512,
generator: Optional[torch.Generator] = None,
embedding_learning_rate: float = 0.001,
diffusion_model_learning_rate: float = 2e-6,
text_embedding_optimization_steps: int = 500,
model_fine_tuning_optimization_steps: int = 1000,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
accelerator = Accelerator(
gradient_accumulation_steps=1,
mixed_precision="fp16",
)
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# Freeze vae and unet
self.vae.requires_grad_(False)
self.unet.requires_grad_(False)
self.text_encoder.requires_grad_(False)
self.unet.eval()
self.vae.eval()
self.text_encoder.eval()
if accelerator.is_main_process:
accelerator.init_trackers(
"imagic",
config={
"embedding_learning_rate": embedding_learning_rate,
"text_embedding_optimization_steps": text_embedding_optimization_steps,
},
)
# get text embeddings for prompt
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncaton=True,
return_tensors="pt",
)
text_embeddings = torch.nn.Parameter(
self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
)
text_embeddings = text_embeddings.detach()
text_embeddings.requires_grad_()
text_embeddings_orig = text_embeddings.clone()
# Initialize the optimizer
optimizer = torch.optim.Adam(
[text_embeddings], # only optimize the embeddings
lr=embedding_learning_rate,
)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_image_dist = self.vae.encode(init_image).latent_dist
init_image_latents = init_latent_image_dist.sample(generator=generator)
init_image_latents = 0.18215 * init_image_latents
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
logger.info("First optimizing the text embedding to better reconstruct the init image")
for _ in range(text_embedding_optimization_steps):
with accelerator.accumulate(text_embeddings):
# Sample noise that we'll add to the latents
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()
text_embeddings.requires_grad_(False)
# Now we fine tune the unet to better reconstruct the image
self.unet.requires_grad_(True)
self.unet.train()
optimizer = torch.optim.Adam(
self.unet.parameters(), # only optimize unet
lr=diffusion_model_learning_rate,
)
progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
logger.info("Next fine tuning the entire model to better reconstruct the init image")
for _ in range(model_fine_tuning_optimization_steps):
with accelerator.accumulate(self.unet.parameters()):
# Sample noise that we'll add to the latents
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()
self.text_embeddings_orig = text_embeddings_orig
self.text_embeddings = text_embeddings
@torch.no_grad()
def __call__(
self,
alpha: float = 1.2,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
guidance_scale: float = 7.5,
eta: float = 0.0,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if self.text_embeddings is None:
raise ValueError("Please run the pipe.train() before trying to generate an image.")
if self.text_embeddings_orig is None:
raise ValueError("Please run the pipe.train() before trying to generate an image.")
text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens = [""]
max_length = self.tokenizer.model_max_length
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if self.device.type == "mps":
# randn does not exist on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -278,7 +278,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -307,7 +307,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.

View File

@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -340,15 +340,13 @@ def get_weighted_text_embeddings(
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
previous_mean = text_embeddings.mean(axis=[-2, -1])
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
@@ -433,19 +431,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -466,24 +451,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -511,23 +478,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = self.device
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@torch.no_grad()
def __call__(
self,
@@ -548,7 +498,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
@@ -611,15 +560,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -812,11 +757,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

View File

@@ -435,7 +435,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
@@ -497,15 +496,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -673,11 +668,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i % callback_steps == 0:
if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
@@ -701,7 +693,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
has_nsfw_concept.append(has_nsfw_concept_i)
image = np.concatenate(images)
else:
has_nsfw_concept = None

View File

@@ -148,7 +148,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -177,7 +177,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.

View File

@@ -295,7 +295,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -324,7 +324,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.

View File

@@ -185,7 +185,7 @@ accelerate launch train_dreambooth.py \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--use_8bit_adam \
--use_8bit_adam
--gradient_checkpointing \
--learning_rate=2e-6 \
--lr_scheduler="constant" \
@@ -291,4 +291,4 @@ python train_dreambooth_flax.py \
--learning_rate=2e-6 \
--num_class_images=200 \
--max_train_steps=800
```
```

View File

@@ -469,7 +469,9 @@ def main(args):
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
@@ -494,12 +496,7 @@ def main(args):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
batch = {
"input_ids": input_ids,

View File

@@ -361,8 +361,7 @@ def main():
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0

View File

@@ -372,7 +372,11 @@ def main():
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -605,7 +609,9 @@ def main():
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)

View File

@@ -29,7 +29,7 @@ accelerate config
### Cat toy example
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
@@ -111,4 +111,4 @@ python textual_inversion_flax.py \
--learning_rate=5.0e-04 --scale_lr \
--output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.
It should be at least 70% faster than the PyTorch script with the same configuration.

View File

@@ -419,7 +419,13 @@ def main():
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
@@ -552,7 +558,9 @@ def main():
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)

View File

@@ -29,24 +29,6 @@ from tqdm.auto import tqdm
logger = get_logger(__name__)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
if not isinstance(arr, torch.Tensor):
arr = torch.from_numpy(arr)
res = arr[timesteps].float().to(timesteps.device)
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -189,16 +171,6 @@ def parse_args():
),
)
parser.add_argument(
"--predict_mode",
type=str,
default="eps",
help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
)
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
@@ -252,7 +224,7 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
@@ -285,8 +257,6 @@ def main(args):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
logger.info(f"Dataset size: {len(dataset)}")
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
@@ -349,20 +319,8 @@ def main(args):
with accelerator.accumulate(model):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.predict_mode == "eps":
loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.predict_mode == "x0":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
snr_weights = alpha_t / (1 - alpha_t)
loss = snr_weights * F.mse_loss(
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()
noise_pred = model(noisy_images, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -397,12 +355,7 @@ def main(args):
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(
generator=generator,
batch_size=args.eval_batch_size,
output_type="numpy",
predict_epsilon=args.predict_mode == "eps",
).images
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")

View File

@@ -1,885 +0,0 @@
"""
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
It currently only supports porting the ITHQ dataset.
ITHQ dataset:
```sh
# From the root directory of diffusers.
# Download the VQVAE checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth
# Download the VQVAE config
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
# Download the main model checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth
# Download the main model config
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml
# run the convert script
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
--checkpoint_path ./ithq_learnable.pth \
--original_config_file ./ithq.yaml \
--vqvae_checkpoint_path ./ithq_vqvae.pth \
--vqvae_original_config_file ./ithq_vqvae.yaml \
--dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
"""
import argparse
import tempfile
import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
" OmegaConf`."
)
# vqvae model
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
def vqvae_model_from_original_config(original_config):
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
original_config = original_config.params
original_encoder_config = original_config.encoder_config.params
original_decoder_config = original_config.decoder_config.params
in_channels = original_encoder_config.in_channels
out_channels = original_decoder_config.out_ch
down_block_types = get_down_block_types(original_encoder_config)
up_block_types = get_up_block_types(original_decoder_config)
assert original_encoder_config.ch == original_decoder_config.ch
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
block_out_channels = tuple(
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
)
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
layers_per_block = original_encoder_config.num_res_blocks
assert original_encoder_config.z_channels == original_decoder_config.z_channels
latent_channels = original_encoder_config.z_channels
num_vq_embeddings = original_config.n_embed
# Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion
norm_num_groups = 32
e_dim = original_config.embed_dim
model = VQModel(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
latent_channels=latent_channels,
num_vq_embeddings=num_vq_embeddings,
norm_num_groups=norm_num_groups,
vq_embed_dim=e_dim,
)
return model
def get_down_block_types(original_encoder_config):
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
num_resolutions = len(original_encoder_config.ch_mult)
resolution = coerce_resolution(original_encoder_config.resolution)
curr_res = resolution
down_block_types = []
for _ in range(num_resolutions):
if curr_res in attn_resolutions:
down_block_type = "AttnDownEncoderBlock2D"
else:
down_block_type = "DownEncoderBlock2D"
down_block_types.append(down_block_type)
curr_res = [r // 2 for r in curr_res]
return down_block_types
def get_up_block_types(original_decoder_config):
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
num_resolutions = len(original_decoder_config.ch_mult)
resolution = coerce_resolution(original_decoder_config.resolution)
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
up_block_types = []
for _ in reversed(range(num_resolutions)):
if curr_res in attn_resolutions:
up_block_type = "AttnUpDecoderBlock2D"
else:
up_block_type = "UpDecoderBlock2D"
up_block_types.append(up_block_type)
curr_res = [r * 2 for r in curr_res]
return up_block_types
def coerce_attn_resolutions(attn_resolutions):
attn_resolutions = OmegaConf.to_object(attn_resolutions)
attn_resolutions_ = []
for ar in attn_resolutions:
if isinstance(ar, (list, tuple)):
attn_resolutions_.append(list(ar))
else:
attn_resolutions_.append([ar, ar])
return attn_resolutions_
def coerce_resolution(resolution):
resolution = OmegaConf.to_object(resolution)
if isinstance(resolution, int):
resolution = [resolution, resolution] # H, W
elif isinstance(resolution, (tuple, list)):
resolution = list(resolution)
else:
raise ValueError("Unknown type of resolution:", resolution)
return resolution
# done vqvae model
# vqvae checkpoint
def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))
# quant_conv
diffusers_checkpoint.update(
{
"quant_conv.weight": checkpoint["quant_conv.weight"],
"quant_conv.bias": checkpoint["quant_conv.bias"],
}
)
# quantize
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
# post_quant_conv
diffusers_checkpoint.update(
{
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
}
)
# decoder
diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))
return diffusers_checkpoint
def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# conv_in
diffusers_checkpoint.update(
{
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
}
)
# down_blocks
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
down_block_prefix = f"encoder.down.{down_block_idx}"
# resnets
for resnet_idx, resnet in enumerate(down_block.resnets):
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
# downsample
# do not include the downsample when on the last down block
# There is no downsample on the last down block
if down_block_idx != len(model.encoder.down_blocks) - 1:
# There's a single downsample in the original checkpoint but a list of downsamples
# in the diffusers model.
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
downsample_prefix = f"{down_block_prefix}.downsample.conv"
diffusers_checkpoint.update(
{
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
}
)
# attentions
if hasattr(down_block, "attentions"):
for attention_idx, _ in enumerate(down_block.attentions):
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=attention_prefix,
)
)
# mid block
# mid block attentions
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
attention_prefix = "encoder.mid.attn_1"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# mid block resnets
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
# the hardcoded prefixes to `block_` are 1 and 2
orig_resnet_idx = diffusers_resnet_idx + 1
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
diffusers_checkpoint.update(
{
# conv_norm_out
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
# conv_out
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
}
)
return diffusers_checkpoint
def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# conv in
diffusers_checkpoint.update(
{
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
}
)
# up_blocks
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
# resnets
for resnet_idx, resnet in enumerate(up_block.resnets):
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
# upsample
# there is no up sample on the last up block
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
# There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples
# in the diffusers model.
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
downsample_prefix = f"{up_block_prefix}.upsample.conv"
diffusers_checkpoint.update(
{
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
}
)
# attentions
if hasattr(up_block, "attentions"):
for attention_idx, _ in enumerate(up_block.attentions):
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=attention_prefix,
)
)
# mid block
# mid block attentions
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
attention_prefix = "decoder.mid.attn_1"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# mid block resnets
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
# the hardcoded prefixes to `block_` are 1 and 2
orig_resnet_idx = diffusers_resnet_idx + 1
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
diffusers_checkpoint.update(
{
# conv_norm_out
"decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
"decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
# conv_out
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
}
)
return diffusers_checkpoint
def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
rv = {
# norm1
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
# conv1
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
# norm2
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
# conv2
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
}
if resnet.conv_shortcut is not None:
rv.update(
{
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
}
)
return rv
def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
return {
# group_norm
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
# query
f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
# key
f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
# value
f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
# proj_attn
f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
:, :, 0, 0
],
f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
}
# done vqvae checkpoint
# transformer model
PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
assert (
original_diffusion_config.target in PORTED_DIFFUSIONS
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
assert (
original_transformer_config.target in PORTED_TRANSFORMERS
), f"{original_transformer_config.target} has not yet been ported to diffusers."
assert (
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config.params
original_transformer_config = original_transformer_config.params
original_content_embedding_config = original_content_embedding_config.params
inner_dim = original_transformer_config["n_embd"]
n_heads = original_transformer_config["n_head"]
# VQ-Diffusion gives dimension of the multi-headed attention layers as the
# number of attention heads times the sequence length (the dimension) of a
# single head. We want to specify our attention blocks with those values
# specified separately
assert inner_dim % n_heads == 0
d_head = inner_dim // n_heads
depth = original_transformer_config["n_layer"]
context_dim = original_transformer_config["condition_dim"]
num_embed = original_content_embedding_config["num_embed"]
# the number of embeddings in the transformer includes the mask embedding.
# the content embedding (the vqvae) does not include the mask embedding.
num_embed = num_embed + 1
height = original_transformer_config["content_spatial_size"][0]
width = original_transformer_config["content_spatial_size"][1]
assert width == height, "width has to be equal to height"
dropout = original_transformer_config["resid_pdrop"]
num_embeds_ada_norm = original_diffusion_config["diffusion_step"]
model_kwargs = {
"attention_bias": True,
"cross_attention_dim": context_dim,
"attention_head_dim": d_head,
"num_layers": depth,
"dropout": dropout,
"num_attention_heads": n_heads,
"num_vector_embeds": num_embed,
"num_embeds_ada_norm": num_embeds_ada_norm,
"norm_num_groups": 32,
"sample_size": width,
"activation_fn": "geglu-approximate",
}
model = Transformer2DModel(**model_kwargs)
return model
# done transformer model
# transformer checkpoint
def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
transformer_prefix = "transformer.transformer"
diffusers_latent_image_embedding_prefix = "latent_image_embedding"
latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"
# DalleMaskImageEmbedding
diffusers_checkpoint.update(
{
f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.emb.weight"
],
f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.height_emb.weight"
],
f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.width_emb.weight"
],
}
)
# transformer blocks
for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"
# ada norm block
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
ada_norm_prefix = f"{transformer_block_prefix}.ln1"
diffusers_checkpoint.update(
transformer_ada_norm_to_diffusers_checkpoint(
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
)
)
# attention block
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
attention_prefix = f"{transformer_block_prefix}.attn1"
diffusers_checkpoint.update(
transformer_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# ada norm block
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"
diffusers_checkpoint.update(
transformer_ada_norm_to_diffusers_checkpoint(
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
)
)
# attention block
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
attention_prefix = f"{transformer_block_prefix}.attn2"
diffusers_checkpoint.update(
transformer_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# norm block
diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
norm_block_prefix = f"{transformer_block_prefix}.ln2"
diffusers_checkpoint.update(
{
f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
}
)
# feedforward block
diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
feedforward_prefix = f"{transformer_block_prefix}.mlp"
diffusers_checkpoint.update(
transformer_feedforward_to_diffusers_checkpoint(
checkpoint,
diffusers_feedforward_prefix=diffusers_feedforward_prefix,
feedforward_prefix=feedforward_prefix,
)
)
# to logits
diffusers_norm_out_prefix = "norm_out"
norm_out_prefix = f"{transformer_prefix}.to_logits.0"
diffusers_checkpoint.update(
{
f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
}
)
diffusers_out_prefix = "out"
out_prefix = f"{transformer_prefix}.to_logits.1"
diffusers_checkpoint.update(
{
f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
}
)
return diffusers_checkpoint
def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
return {
f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
}
def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
return {
# key
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
# query
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
# value
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
# linear out
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
}
def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
return {
f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
}
# done transformer checkpoint
def read_config_file(filename):
# The yaml file contains annotations that certain values should
# loaded as tuples. By default, OmegaConf will panic when reading
# these. Instead, we can manually read the yaml with the FullLoader and then
# construct the OmegaConf object.
with open(filename) as f:
original_config = yaml.load(f, FullLoader)
return OmegaConf.create(original_config)
# We take separate arguments for the vqvae because the ITHQ vqvae config file
# is separate from the config file for the rest of the model.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--vqvae_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the vqvae checkpoint to convert.",
)
parser.add_argument(
"--vqvae_original_config_file",
default=None,
type=str,
required=True,
help="The YAML config file corresponding to the original architecture for the vqvae.",
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--original_config_file",
default=None,
type=str,
required=True,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--checkpoint_load_device",
default="cpu",
type=str,
required=False,
help="The device passed to `map_location` when loading checkpoints.",
)
# See link for how ema weights are always selected
# https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
parser.add_argument(
"--no_use_ema",
action="store_true",
required=False,
help=(
"Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set"
" it as the original VQ-Diffusion always uses the ema weights when loading models."
),
)
args = parser.parse_args()
use_ema = not args.no_use_ema
print(f"loading checkpoints to {args.checkpoint_load_device}")
checkpoint_map_location = torch.device(args.checkpoint_load_device)
# vqvae_model
print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")
vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]
with init_empty_weights():
vqvae_model = vqvae_model_from_original_config(vqvae_original_config)
vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)
with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
del vqvae_diffusers_checkpoint
del vqvae_checkpoint
load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")
print("done loading vqvae")
# done vqvae_model
# transformer_model
print(
f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
f" {use_ema}"
)
original_config = read_config_file(args.original_config_file).model
diffusion_config = original_config.params.diffusion_config
transformer_config = original_config.params.diffusion_config.params.transformer_config
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
if use_ema:
if "ema" in pre_checkpoint:
checkpoint = {}
for k, v in pre_checkpoint["model"].items():
checkpoint[k] = v
for k, v in pre_checkpoint["ema"].items():
# The ema weights are only used on the transformer. To mimic their key as if they came
# from the state_dict for the top level model, we prefix with an additional "transformer."
# See the source linked in the args.use_ema config for more information.
checkpoint[f"transformer.{k}"] = v
else:
print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
checkpoint = pre_checkpoint["model"]
else:
checkpoint = pre_checkpoint["model"]
del pre_checkpoint
with init_empty_weights():
transformer_model = transformer_model_from_original_config(
diffusion_config, transformer_config, content_embedding_config
)
diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
transformer_model, checkpoint
)
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
del diffusers_transformer_checkpoint
del checkpoint
load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")
print("done loading transformer")
# done transformer_model
# text encoder
print("loading CLIP text encoder")
clip_name = "openai/clip-vit-base-patch32"
# The original VQ-Diffusion specifies the pad value by the int used in the
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
# specifies the pad value via the token before it has been tokenized. The `!` pad
# token is the same as padding with the `0` pad value.
pad_token = "!"
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
text_encoder_model = CLIPTextModel.from_pretrained(
clip_name,
# `CLIPTextModel` does not support device_map="auto"
# device_map="auto"
)
print("done loading CLIP text encoder")
# done text encoder
# scheduler
scheduler_model = VQDiffusionScheduler(
# the scheduler has the same number of embeddings as the transformer
num_vec_classes=transformer_model.num_vector_embeds
)
# done scheduler
print(f"saving VQ diffusion model, path: {args.dump_path}")
pipe = VQDiffusionPipeline(
vqvae=vqvae_model,
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
print("done writing VQ diffusion model")

View File

@@ -89,10 +89,11 @@ _deps = [
"huggingface-hub>=0.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
"parameterized",
"pytest",
"pytest-timeout",
@@ -178,7 +179,9 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"accelerate",
"datasets",
"onnxruntime",
"parameterized",
"pytest",
"pytest-timeout",
@@ -187,7 +190,7 @@ extras["test"] = deps_list(
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch", "accelerate")
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
@@ -210,7 +213,7 @@ install_requires = [
setup(
name="diffusers",
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -9,7 +9,7 @@ from .utils import (
)
__version__ = "0.8.0.dev0"
__version__ = "0.7.0.dev0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -18,7 +18,7 @@ from .utils import logging
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -36,22 +36,16 @@ if is_torch_available():
KarrasVePipeline,
LDMPipeline,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
)
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
else:
@@ -64,13 +58,11 @@ else:
if is_torch_available() and is_transformers_available():
from .pipelines import (
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
VQDiffusionPipeline,
)
else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
@@ -93,7 +85,6 @@ if is_flax_available():
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,

View File

@@ -16,7 +16,6 @@
""" ConfigMixin base class and utilities."""
import dataclasses
import functools
import importlib
import inspect
import json
import os
@@ -49,13 +48,9 @@ class ConfigMixin:
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""
config_name = None
ignore_for_config = []
_compatible_classes = []
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -285,14 +280,9 @@ class ConfigMixin:
return config_dict
@staticmethod
def _get_init_keys(cls):
return set(dict(inspect.signature(cls.__init__).parameters).keys())
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
# remove general kwargs if present in dict
if "kwargs" in expected_keys:
@@ -302,36 +292,9 @@ class ConfigMixin:
for arg in cls._flax_internal_args:
expected_keys.remove(arg)
# 2. Remove attributes that cannot be expected from expected config attributes
# remove keys to be ignored
if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config)
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
# remove attributes from compatible classes that orig cannot expect
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
# filter out None potentially undefined dummy classes
compatible_classes = [c for c in compatible_classes if c is not None]
expected_keys_comp_cls = set()
for c in compatible_classes:
expected_keys_c = cls._get_init_keys(c)
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
# remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
init_dict = {}
for key in expected_keys:
if key in kwargs:
@@ -341,7 +304,8 @@ class ConfigMixin:
# use value from config dict
init_dict[key] = config_dict.pop(key)
# 4. Give nice warning if unexpected values have been passed
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
if len(config_dict) > 0:
logger.warning(
f"The config attributes {config_dict} were passed to {cls.__name__}, "
@@ -349,16 +313,14 @@ class ConfigMixin:
f"{cls.config_name} configuration file."
)
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
unused_kwargs = {**config_dict, **kwargs}
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.info(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}
return init_dict, unused_kwargs
@classmethod

View File

@@ -13,10 +13,11 @@ deps = {
"huggingface-hub": "huggingface-hub>=0.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"parameterized": "parameterized",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",

View File

@@ -16,25 +16,13 @@
import os
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
from typing import Optional
from huggingface_hub import HfFolder, Repository, whoami
from . import __version__
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
_flax_version,
_jax_version,
_onnxruntime_version,
_torch_version,
is_flax_available,
is_modelcards_available,
is_onnx_available,
is_torch_available,
)
from .pipeline_utils import DiffusionPipeline
from .utils import deprecate, is_modelcards_available, logging
if is_modelcards_available():
@@ -45,32 +33,6 @@ logger = logging.get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
Formats a user-agent string with basic info about a request.
"""
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
if DISABLE_TELEMETRY:
return ua + "; telemetry/off"
if is_torch_available():
ua += f"; torch/{_torch_version}"
if is_flax_available():
ua += f"; jax/{_jax_version}"
ua += f"; flax/{_flax_version}"
if is_onnx_available():
ua += f"; onnxruntime/{_onnxruntime_version}"
# CI will set this value to True
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
ua += "; is_ci/true"
if isinstance(user_agent, dict):
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
return ua
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
@@ -139,7 +101,7 @@ def init_git_repo(args, at_init: bool = False):
def push_to_hub(
args,
pipeline,
pipeline: DiffusionPipeline,
repo: Repository,
commit_message: Optional[str] = "End of training",
blocking: bool = True,

View File

@@ -21,37 +21,18 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
from diffusers.utils import is_accelerate_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import __version__
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
is_accelerate_available,
is_torch_version,
logging,
)
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
logger = logging.get_logger(__name__)
if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
else:
_LOW_CPU_MEM_USAGE_DEFAULT = False
if is_accelerate_available():
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
def get_parameter_device(parameter: torch.nn.Module):
try:
return next(parameter.parameters()).device
@@ -287,19 +268,6 @@ class ModelMixin(torch.nn.Module):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
<Tip>
@@ -328,41 +296,6 @@ class ModelMixin(torch.nn.Module):
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)
# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
user_agent = {
"diffusers": __version__,
@@ -445,8 +378,12 @@ class ModelMixin(torch.nn.Module):
# restore default dtype
if low_cpu_mem_usage:
# Instantiate model with empty weights
if device_map == "auto":
if is_accelerate_available():
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config_path,
@@ -463,17 +400,7 @@ class ModelMixin(torch.nn.Module):
**kwargs,
)
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file)
# move the parms from meta device to cpu
for param_name, param in state_dict.items():
set_module_tensor_to_device(model, param_name, param_device, value=param)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by deafult the device_map is None and the weights are loaded on the CPU
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
loading_info = {
"missing_keys": [],

View File

@@ -16,7 +16,6 @@ from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .attention import Transformer2DModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel

View File

@@ -12,218 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
for the unnoised latent pixels.
"""
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = in_channels is not None
self.is_input_vectorized = num_vector_embeds is not None
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized:
raise ValueError(
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if self.is_input_continuous:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
# 1. Input
if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
# 3. Output
if self.is_input_continuous:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
class AttentionBlock(nn.Module):
"""
@@ -233,19 +27,19 @@ class AttentionBlock(nn.Module):
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
@@ -254,7 +48,7 @@ class AttentionBlock(nn.Module):
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
@@ -310,108 +104,112 @@ class AttentionBlock(nn.Module):
return hidden_states
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none
# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, context=None, timestep=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 3. Feed-forward
def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
@@ -420,28 +218,20 @@ class CrossAttention(nn.Module):
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
context_dim = context_dim if context_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
@@ -449,15 +239,12 @@ class CrossAttention(nn.Module):
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -490,19 +277,13 @@ class CrossAttention(nn.Module):
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
return self.to_out(hidden_states)
def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
@@ -554,53 +335,31 @@ class CrossAttention(nn.Module):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value):
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
project_in = GEGLU(dim, inner_dim)
if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(geglu)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
return self.net(hidden_states)
# feedforward
@@ -609,8 +368,8 @@ class GEGLU(nn.Module):
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
@@ -626,38 +385,3 @@ class GEGLU(nn.Module):
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x

View File

@@ -142,7 +142,7 @@ class FlaxBasicTransformerBlock(nn.Module):
return hidden_states
class FlaxTransformer2DModel(nn.Module):
class FlaxSpatialTransformer(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf

View File

@@ -126,68 +126,3 @@ class GaussianFourierProjection(nn.Module):
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
For VQ-diffusion:
Output vector embeddings are used as input for the transformer.
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""
def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()
self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim
self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)
def forward(self, index):
emb = self.emb(index)
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)
pos_emb = height_emb + width_emb
# 1 x H x W x D -> 1 x L xD
pos_emb = pos_emb.view(1, self.height * self.width, -1)
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb

View File

@@ -17,41 +17,23 @@ import flax.linen as nn
import jax.numpy as jnp
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
embedding_dim: int,
freq_shift: float = 1,
min_timescale: float = 1,
max_timescale: float = 1.0e4,
flip_sin_to_cos: bool = False,
scale: float = 1.0,
) -> jnp.ndarray:
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
"""
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
num_timescales = float(embedding_dim // 2)
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
# scale embeddings
scaled_time = scale * emb
if flip_sin_to_cos:
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
else:
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
return signal
:param timesteps: a 1-D tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - freq_shift)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb
class FlaxTimestepEmbedding(nn.Module):
@@ -88,6 +70,4 @@ class FlaxTimesteps(nn.Module):
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
)
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)

View File

@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, Transformer2DModel
from .attention import AttentionBlock, SpatialTransformer
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -109,19 +109,6 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownEncoderBlock2D":
return AttnDownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
@@ -213,17 +200,6 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "AttnUpDecoderBlock2D":
return AttnUpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -273,7 +249,7 @@ class UNetMidBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -349,13 +325,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
SpatialTransformer(
in_channels,
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -391,14 +367,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -451,7 +423,7 @@ class AttnDownBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
num_groups=resnet_groups,
)
)
@@ -462,7 +434,7 @@ class AttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -530,13 +502,13 @@ class CrossAttnDownBlock2D(nn.Module):
)
)
attentions.append(
Transformer2DModel(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -546,7 +518,7 @@ class CrossAttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -570,32 +542,25 @@ class CrossAttnDownBlock2D(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def create_custom_forward(module):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
@@ -651,7 +616,7 @@ class DownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -729,7 +694,7 @@ class DownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -790,7 +755,7 @@ class AttnDownEncoderBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
num_groups=resnet_groups,
)
)
@@ -801,7 +766,7 @@ class AttnDownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -886,7 +851,7 @@ class AttnSkipDownBlock2D(nn.Module):
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
@@ -966,7 +931,7 @@ class SkipDownBlock2D(nn.Module):
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
@@ -1041,7 +1006,7 @@ class AttnUpBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
num_groups=resnet_groups,
)
)
@@ -1088,6 +1053,7 @@ class CrossAttnUpBlock2D(nn.Module):
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_upsample=True,
):
super().__init__()
@@ -1116,13 +1082,13 @@ class CrossAttnUpBlock2D(nn.Module):
)
)
attentions.append(
Transformer2DModel(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1152,10 +1118,6 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(
self,
hidden_states,
@@ -1172,22 +1134,19 @@ class CrossAttnUpBlock2D(nn.Module):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def create_custom_forward(module):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1367,7 +1326,7 @@ class AttnUpDecoderBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
num_groups=resnet_groups,
)
)

View File

@@ -15,7 +15,7 @@
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxTransformer2DModel
from .attention_flax import FlaxSpatialTransformer
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
@@ -63,7 +63,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
)
resnets.append(res_block)
attn_block = FlaxTransformer2DModel(
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -196,7 +196,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
)
resnets.append(res_block)
attn_block = FlaxTransformer2DModel(
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -326,7 +326,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
attentions = []
for _ in range(self.num_layers):
attn_block = FlaxTransformer2DModel(
attn_block = FlaxSpatialTransformer(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,

View File

@@ -225,17 +225,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value

View File

@@ -233,16 +233,14 @@ class VectorQuantizer(nn.Module):
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
):
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.vq_embed_dim = vq_embed_dim
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
@@ -289,7 +287,7 @@ class VectorQuantizer(nn.Module):
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
@@ -411,7 +409,6 @@ class VQModel(ModelMixin, ConfigMixin):
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
"""
@register_to_config
@@ -428,7 +425,6 @@ class VQModel(ModelMixin, ConfigMixin):
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
):
super().__init__()
@@ -444,11 +440,11 @@ class VQModel(ModelMixin, ConfigMixin):
double_z=False,
)
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.quantize = VectorQuantizer(
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
# pass init params to Decoder
self.decoder = Decoder(

View File

@@ -29,7 +29,6 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .hub_utils import http_user_agent
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
@@ -161,10 +160,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -276,7 +271,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
```
"""
@@ -306,22 +301,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
requested_pipeline_class = (
requested_pipeline_class
if requested_pipeline_class.startswith("Flax")
else "Flax" + requested_pipeline_class
)
user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -332,8 +311,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path
@@ -351,7 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, class_name)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
@@ -371,11 +348,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

View File

@@ -30,10 +30,9 @@ from packaging import version
from PIL import Image
from tqdm.auto import tqdm
from . import __version__
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
CONFIG_NAME,
@@ -42,8 +41,6 @@ from .utils import (
WEIGHTS_NAME,
BaseOutput,
deprecate,
is_accelerate_available,
is_torch_version,
is_transformers_available,
logging,
)
@@ -179,10 +176,6 @@ class DiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -209,13 +202,13 @@ class DiffusionPipeline(ConfigMixin):
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
" `float16` operations on those devices in PyTorch. Please remove the"
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
)
module.to(torch_device)
return self
@@ -331,19 +324,6 @@ class DiffusionPipeline(ConfigMixin):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
@@ -380,7 +360,7 @@ class DiffusionPipeline(ConfigMixin):
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
```
"""
@@ -396,34 +376,6 @@ class DiffusionPipeline(ConfigMixin):
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -443,20 +395,13 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = "*.msgpack"
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
# download all allow_patterns
cached_folder = snapshot_download(
@@ -468,7 +413,6 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
@@ -525,11 +469,6 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
@@ -616,12 +555,8 @@ class DiffusionPipeline(ConfigMixin):
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
@@ -663,7 +598,7 @@ class DiffusionPipeline(ConfigMixin):
... StableDiffusionInpaintPipeline,
... )
>>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
```

View File

@@ -16,7 +16,7 @@ or created independently from each other.
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
More specifically, we strive to provide pipelines that
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).

View File

@@ -7,7 +7,6 @@ if is_torch_available():
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
else:
@@ -16,13 +15,11 @@ else:
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
)
from .vq_diffusion import VQDiffusionPipeline
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import (

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Optional, Tuple, Union
import torch
@@ -44,7 +44,6 @@ class DDIMPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
eta: float = 0.0,
num_inference_steps: int = 50,
use_clipped_model_output: Optional[bool] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
@@ -61,9 +60,6 @@ class DDIMPipeline(DiffusionPipeline):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
use_clipped_model_output (`bool`, *optional*, defaults to `None`):
if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
downstream to the scheduler. So use `None` for schedulers which don't support this argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -86,14 +82,6 @@ class DDIMPipeline(DiffusionPipeline):
# set step values
self.scheduler.set_timesteps(num_inference_steps)
# Ignore use_clipped_model_output if the scheduler doesn't accept this argument
accepts_use_clipped_model_output = "use_clipped_model_output" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_kwargs = {}
if accepts_use_clipped_model_output:
extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t).sample
@@ -101,7 +89,7 @@ class DDIMPipeline(DiffusionPipeline):
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -45,7 +45,6 @@ class DDPMPipeline(DiffusionPipeline):
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
predict_epsilon: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
@@ -85,9 +84,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -1 +0,0 @@
from .pipeline_repaint import RePaintPipeline

View File

@@ -1,140 +0,0 @@
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import PIL
from tqdm.auto import tqdm
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler
def _preprocess_image(image: PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image
def _preprocess_mask(mask: PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
return mask
class RePaintPipeline(DiffusionPipeline):
unet: UNet2DModel
scheduler: RePaintScheduler
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
original_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
num_inference_steps: int = 250,
eta: float = 0.0,
jump_length: int = 10,
jump_n_sample: int = 10,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
The original image to inpaint on.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
The mask_image where 0.0 values define which part of the original image to inpaint (change).
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
eta (`float`):
The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM
and 1.0 is DDPM scheduler respectively.
jump_length (`int`, *optional*, defaults to 10):
The number of steps taken forward in time before going backward in time for a single jump ("j" in
RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
jump_n_sample (`int`, *optional*, defaults to 10):
The number of times we will make forward time jump for a given chosen time sample. Take a look at
Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if not isinstance(original_image, torch.FloatTensor):
original_image = _preprocess_image(original_image)
original_image = original_image.to(self.device)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = _preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
# sample gaussian noise to begin the loop
image = torch.randn(
original_image.shape,
generator=generator,
device=self.device,
)
image = image.to(self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
self.scheduler.eta = eta
t_last = self.scheduler.timesteps[0] + 1
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
if t < t_last:
# predict the noise residual
model_output = self.unet(image, t).sample
# compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
else:
# compute the reverse: x_t-1 -> x_t
image = self.scheduler.undo_step(image, t_last, generator)
t_last = t
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -91,7 +91,11 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
lms = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -103,74 +107,3 @@ image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
### CycleDiffusion using Stable Diffusion and DDIM scheduler
```python
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the scheduler. CycleDiffusion only supports stochastic schedulers.
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("horse.png")
# let's specify a prompt
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"
# call the pipeline
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.8,
guidance_scale=2,
source_guidance_scale=1,
).images[0]
image.save("horse_to_elephant.png")
# let's try another example
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("black.png")
source_prompt = "A black colored car"
prompt = "A blue colored car"
# call the pipeline
torch.manual_seed(0)
image = pipe(
prompt=prompt,
source_prompt=source_prompt,
init_image=init_image,
num_inference_steps=100,
eta=0.1,
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
).images[0]
image.save("black_to_blue.png")
```

View File

@@ -28,7 +28,6 @@ class StableDiffusionPipelineOutput(BaseOutput):
if is_transformers_available() and is_torch_available():
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline

View File

@@ -1,527 +0,0 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def posterior_sample(scheduler, latents, timestep, clean_latents, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
if prev_timestep <= 0:
return clean_latents
# 2. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
)
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
# direction pointing to x_t
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
noise = std_dev_t * torch.randn(clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device)
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
return prev_latents
def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = (
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if scheduler.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
variance ** (0.5) * eta
)
return noise
class CycleDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
source_guidance_scale: Optional[float] = 1,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.1,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
source_guidance_scale (`float`, *optional*, defaults to 1):
Guidance scale for the source prompt. This is useful to control the amount of influence the source
prompt for encoding.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.1):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if batch_size != 1:
raise ValueError(
"At the moment only `batch_size=1` is supported for prompts, but you seem to have passed multiple"
f" prompts: {prompt}. Please make sure to pass only a single prompt."
)
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
source_text_inputs = self.tokenizer(
source_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
source_text_input_ids = source_text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# get unconditional embeddings for classifier free guidance
uncond_tokens = [""]
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
source_uncond_tokens = [""]
max_length = source_text_input_ids.shape[-1]
source_uncond_input = self.tokenizer(
source_uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
batch_size * num_images_per_prompt, dim=0
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])
# encode the init image into latents and scale the latents
latents_dtype = text_embeddings.dtype
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
if isinstance(prompt, str):
prompt = [prompt]
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
clean_latents = init_latents
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if not (accepts_eta and (0 < eta <= 1)):
raise ValueError(
"Currently, only the DDIM scheduler is supported. Please make sure that `pipeline.scheduler` is of"
f" type {DDIMScheduler.__class__} and not {self.scheduler.__class__}."
)
extra_step_kwargs["eta"] = eta
latents = init_latents
source_latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -14,12 +14,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -48,8 +43,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
[`FlaxDPMSolverMultistepScheduler`].
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -63,9 +57,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,

View File

@@ -5,7 +5,6 @@ import numpy as np
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -37,34 +36,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,

View File

@@ -90,19 +90,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@@ -104,19 +104,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"

View File

@@ -9,14 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -59,14 +52,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
@@ -86,19 +72,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -119,24 +92,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -178,8 +133,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
cpu_offload(cpu_offloaded_model, device)
@torch.no_grad()
def __call__(
@@ -303,7 +257,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -332,7 +286,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
@@ -379,11 +333,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

View File

@@ -5,19 +5,12 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -70,9 +63,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
],
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
@@ -92,19 +83,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -152,41 +130,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@torch.no_grad()
def __call__(
self,
@@ -313,7 +256,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -337,9 +280,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -394,11 +335,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)

View File

@@ -5,7 +5,6 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -91,20 +90,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
" Hub, it would be very nice if you could open a Pull request for the"
" `scheduler/scheduler_config.json` file"
)
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -152,41 +137,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@torch.no_grad()
def __call__(
self,
@@ -320,7 +270,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -349,7 +299,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
@@ -379,14 +329,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# prepare mask and masked_image
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
# encode the mask image into latents space so we can concatenate it to the latents
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
@@ -401,9 +348,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
@@ -435,11 +379,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

View File

@@ -96,19 +96,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -284,7 +271,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -312,9 +299,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -367,11 +352,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)

View File

@@ -1 +0,0 @@
from .pipeline_vq_diffusion import VQDiffusionPipeline

View File

@@ -1,253 +0,0 @@
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Tuple, Union
import torch
from diffusers import Transformer2DModel, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VQDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using VQ Diffusion
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vqvae ([`VQModel`]):
Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. VQ Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
transformer ([`Transformer2DModel`]):
Conditional transformer to denoise the encoded image latents.
scheduler ([`VQDiffusionScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
vqvae: VQModel
text_encoder: CLIPTextModel
tokenizer: CLIPTokenizer
transformer: Transformer2DModel
scheduler: VQDiffusionScheduler
def __init__(
self,
vqvae: VQModel,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
transformer: Transformer2DModel,
scheduler: VQDiffusionScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 100,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
`truncation_rate` are set to zero.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor` of shape (batch), *optional*):
Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
be generated of completely masked latent pixels.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
batch_size = batch_size * num_images_per_prompt
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# get the initial completely masked latents unless the user supplied it
latents_shape = (batch_size, self.transformer.num_latent_pixels)
if latents is None:
mask_class = self.transformer.num_vector_embeds - 1
latents = torch.full(latents_shape, mask_class).to(self.device)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
raise ValueError(
"Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
f" {self.transformer.num_vector_embeds - 1} (inclusive)."
)
latents = latents.to(self.device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps.to(self.device)
sample = latents
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.truncate(model_output, truncation_rate)
# remove `log(0)`'s (`-inf`s)
model_output = model_output.clamp(-70)
# compute the previous noisy sample x_t -> x_t-1
sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, sample)
embedding_channels = self.vqvae.config.vq_embed_dim
embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
"""
Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The
lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero.
"""
sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
# Ensure that at least the largest probability is not zeroed out
all_true = torch.full_like(keep_mask[:, 0:1, :], True)
keep_mask = torch.cat((all_true, keep_mask), dim=1)
keep_mask = keep_mask[:, :-1, :]
keep_mask = keep_mask.gather(1, indices.argsort(1))
rv = log_p_x_0.clone()
rv[~keep_mask] = -torch.inf # -inf = log(0)
return rv

View File

@@ -19,24 +19,18 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
if is_torch_available():
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
else:
from ..utils.dummy_pt_objects import * # noqa F403
if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler

View File

@@ -109,15 +109,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"PNDMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,
@@ -209,7 +200,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
@@ -222,14 +212,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns:
@@ -289,17 +273,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
if variance_noise is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
" `variance_noise` stays `None`."
)
if variance_noise is None:
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
device
)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance

View File

@@ -102,15 +102,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,

View File

@@ -1,506 +0,0 @@
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
samples, and it can generate quite good samples even in only 10 steps.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++`.
algorithm_type (`str`, default `dpmsolver++`):
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
slightly better, so we recommend to use the `midpoint` type.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
solver_order: int = 2,
predict_epsilon: bool = True,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
"""
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm.
Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
DPM-Solver++ for both noise prediction model and data prediction model.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the converted model output.
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
else:
x0_pred = model_output
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon:
return model_output
else:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
def dpm_solver_first_order_update(
self,
model_output: torch.FloatTensor,
timestep: int,
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the second-order multistep DPM-Solver.
Args:
model_output_list (`List[torch.FloatTensor]`):
direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
)
return x_t
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.FloatTensor],
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
"""
One step for the third-order multistep DPM-Solver.
Args:
model_output_list (`List[torch.FloatTensor]`):
direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
Returns:
`torch.FloatTensor`: the sample tensor at the previous timestep.
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the multistep DPM-Solver.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, timestep, sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1,590 +0,0 @@
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import flax
import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
# setable values
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
# running values
model_outputs: Optional[jnp.ndarray] = None
lower_order_nums: Optional[int] = None
step_index: Optional[int] = None
prev_timestep: Optional[int] = None
cur_sample: Optional[jnp.ndarray] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
state: DPMSolverMultistepSchedulerState
class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
samples, and it can generate quite good samples even in only 10 steps.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++`.
algorithm_type (`str`, default `dpmsolver++`):
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
slightly better, so we recommend to use the `midpoint` type.
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
"""
@property
def has_state(self):
return True
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
solver_order: int = 2,
predict_epsilon: bool = True,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# Currently we only support VP-type noise schedule
self.alpha_t = jnp.sqrt(self.alphas_cumprod)
self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod)
self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
def create_state(self):
return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
) -> DPMSolverMultistepSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`DPMSolverMultistepSchedulerState`):
the `FlaxDPMSolverMultistepScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
shape (`Tuple`):
the shape of the samples to be generated.
"""
timesteps = (
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.astype(jnp.int32)
)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
model_outputs=jnp.zeros((self.config.solver_order,) + shape),
lower_order_nums=0,
step_index=0,
prev_timestep=-1,
cur_sample=jnp.zeros(shape),
)
def convert_model_output(
self,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
"""
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm.
Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
DPM-Solver++ for both noise prediction model and data prediction model.
Args:
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the converted model output.
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
else:
x0_pred = model_output
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = jnp.percentile(
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
)
dynamic_max_val = jnp.maximum(
dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
)
x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon:
return model_output
else:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
) -> jnp.ndarray:
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
Args:
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
t, s0 = prev_timestep, timestep
m0 = model_output
lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
"""
One step for the second-order multistep DPM-Solver.
Args:
model_output_list (`List[jnp.ndarray]`):
direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
- 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
)
return x_t
def multistep_dpm_solver_third_order_update(
self,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
"""
One step for the third-order multistep DPM-Solver.
Args:
model_output_list (`List[jnp.ndarray]`):
direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`): current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
self.lambda_t[t],
self.lambda_t[s0],
self.lambda_t[s1],
self.lambda_t[s2],
)
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
- (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t
def step(
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
from the learned model outputs (most often the predicted noise).
Args:
state (`DPMSolverMultistepSchedulerState`):
the `FlaxDPMSolverMultistepScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class
Returns:
[`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
prev_timestep = jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
lambda _: 0,
lambda _: state.timesteps[state.step_index + 1],
(),
)
model_output = self.convert_model_output(model_output, timestep, sample)
model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
model_outputs_new = model_outputs_new.at[-1].set(model_output)
state = state.replace(
model_outputs=model_outputs_new,
prev_timestep=prev_timestep,
cur_sample=sample,
)
def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
return self.dpm_solver_first_order_update(
state.model_outputs[-1],
state.timesteps[state.step_index],
state.prev_timestep,
state.cur_sample,
)
def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]])
return self.multistep_dpm_solver_second_order_update(
state.model_outputs,
timestep_list,
state.prev_timestep,
state.cur_sample,
)
def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
timestep_list = jnp.array(
[
state.timesteps[state.step_index - 2],
state.timesteps[state.step_index - 1],
state.timesteps[state.step_index],
]
)
return self.multistep_dpm_solver_third_order_update(
state.model_outputs,
timestep_list,
state.prev_timestep,
state.cur_sample,
)
if self.config.solver_order == 2:
return step_2(state)
elif self.config.lower_order_final and len(state.timesteps) < 15:
return jax.lax.cond(
state.lower_order_nums < 2,
step_2,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 2,
step_2,
step_3,
state,
),
state,
)
else:
return jax.lax.cond(
state.lower_order_nums < 2,
step_2,
step_3,
state,
)
if self.config.solver_order == 1:
prev_sample = step_1(state)
elif self.config.lower_order_final and len(state.timesteps) < 15:
prev_sample = jax.lax.cond(
state.lower_order_nums < 1,
step_1,
lambda state: jax.lax.cond(
state.step_index == len(state.timesteps) - 1,
step_1,
step_23,
state,
),
state,
)
else:
prev_sample = jax.lax.cond(
state.lower_order_nums < 1,
step_1,
step_23,
state,
)
state = state.replace(
lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
step_index=(state.step_index + 1),
)
if not return_dict:
return (prev_sample, state)
return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
def scale_model_input(
self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
state (`DPMSolverMultistepSchedulerState`):
the `FlaxDPMSolverMultistepScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep
Returns:
`jnp.ndarray`: scaled input sample
"""
return sample
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.0
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1,267 +0,0 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
Returns:
`torch.FloatTensor`: scaled input sample
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`float`): current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
generator (`torch.Generator`, optional): Random number generator.
return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
)
if not self.is_scale_input_called:
logger.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
sigma_from = self.sigmas[step_index]
sigma_to = self.sigmas[step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
dt = sigma_down - sigma
prev_sample = sample + derivative * dt
device = model_output.device if torch.is_tensor(model_output) else "cpu"
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
prev_sample = prev_sample + noise * sigma_up
if not return_dict:
return (prev_sample,)
return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample
)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1,276 +0,0 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
Returns:
`torch.FloatTensor`: scaled input sample
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`float`): current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
s_churn (`float`)
s_tmin (`float`)
s_tmax (`float`)
s_noise (`float`)
generator (`torch.Generator`, optional): Random number generator.
return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
)
if not self.is_scale_input_called:
logger.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device if torch.is_tensor(model_output) else "cpu"
if str(device) == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
dt = self.sigmas[step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
if not return_dict:
return (prev_sample,)
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -21,7 +21,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -67,15 +67,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,
@@ -212,7 +203,22 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
deprecate(
"timestep as an index",
"0.8.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
standard_warn=False,
)
step_index = timestep
else:
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -255,7 +261,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = timesteps.to(original_samples.device)
schedule_timesteps = self.timesteps
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
deprecate(
"timesteps as indices",
"0.8.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
" pass values from `scheduler.timesteps` as timesteps.",
standard_warn=False,
)
step_indices = timesteps
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):

View File

@@ -88,15 +88,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
def __init__(
self,

View File

@@ -1,322 +0,0 @@
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class RePaintSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from
the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: torch.FloatTensor
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class RePaintScheduler(SchedulerMixin, ConfigMixin):
"""
RePaint is a schedule for DDPM inpainting inside a given mask.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
eta (`float`):
The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and
1.0 is DDPM scheduler respectively.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
eta: float = 0.0,
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
self.final_alpha_cumprod = torch.tensor(1.0)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.eta = eta
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def set_timesteps(
self,
num_inference_steps: int,
jump_length: int = 10,
jump_n_sample: int = 10,
device: Union[str, torch.device] = None,
):
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
timesteps = []
jumps = {}
for j in range(0, num_inference_steps - jump_length, jump_length):
jumps[j] = jump_n_sample - 1
t = num_inference_steps
while t >= 1:
t = t - 1
timesteps.append(t)
if jumps.get(t, 0) > 0:
jumps[t] = jumps[t] - 1
for _ in range(jump_length):
t = t + 1
timesteps.append(t)
timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t):
prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from
# https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
# previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
# variance to pred_sample
# Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
# without eta.
# variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
original_image: torch.FloatTensor,
mask: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned
diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
original_image (`torch.FloatTensor`):
the original image to inpaint on.
mask (`torch.FloatTensor`):
the mask where 0.0 values define which part of the original image to inpaint (change).
generator (`torch.Generator`, *optional*): random number generator.
return_dict (`bool`): option for returning tuple rather than
DDPMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
# substitute formula (7) in the algorithm coming from DDPM paper
# (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper.
# DDIM schedule gives the same results as DDPM with eta = 1.0
# Noise is being reused in 7. and 8., but no impact on quality has
# been observed.
# 5. Add noise
noise = torch.randn(
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
variance = 0
if t > 0 and self.eta > 0:
variance = std_dev_t * noise
# 6. compute "direction pointing to x_t" of formula (12)
# from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
# 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
if not return_dict:
return (
pred_prev_sample,
pred_original_sample,
)
return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def undo_step(self, sample, timestep, generator=None):
n = self.config.num_train_timesteps // self.num_inference_steps
for i in range(n):
beta = self.betas[timestep + i]
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
return sample
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1,494 +0,0 @@
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class VQDiffusionSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.LongTensor
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
"""
Convert batch of vector of class indices into batch of log onehot vectors
Args:
x (`torch.LongTensor` of shape `(batch size, vector length)`):
Batch of class indices
num_classes (`int`):
number of classes to be used for the onehot vectors
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
Log onehot vectors
"""
x_onehot = F.one_hot(x, num_classes)
x_onehot = x_onehot.permute(0, 2, 1)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
"""
Apply gumbel noise to `logits`
"""
uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
noised = gumbel_noise + logits
return noised
def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
"""
Cumulative and non-cumulative alpha schedules.
See section 4.1.
"""
att = (
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
+ alpha_cum_start
)
att = np.concatenate(([1], att))
at = att[1:] / att[:-1]
att = np.concatenate((att[1:], [1]))
return at, att
def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
"""
Cumulative and non-cumulative gamma schedules.
See section 4.1.
"""
ctt = (
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
+ gamma_cum_start
)
ctt = np.concatenate(([0], ctt))
one_minus_ctt = 1 - ctt
one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
ct = 1 - one_minus_ct
ctt = np.concatenate((ctt[1:], [0]))
return ct, ctt
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
"""
The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
diffusion timestep.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822
Args:
num_vec_classes (`int`):
The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
latent pixel.
num_train_timesteps (`int`):
Number of diffusion steps used to train the model.
alpha_cum_start (`float`):
The starting cumulative alpha value.
alpha_cum_end (`float`):
The ending cumulative alpha value.
gamma_cum_start (`float`):
The starting cumulative gamma value.
gamma_cum_end (`float`):
The ending cumulative gamma value.
"""
@register_to_config
def __init__(
self,
num_vec_classes: int,
num_train_timesteps: int = 100,
alpha_cum_start: float = 0.99999,
alpha_cum_end: float = 0.000009,
gamma_cum_start: float = 0.000009,
gamma_cum_end: float = 0.99999,
):
self.num_embed = num_vec_classes
# By convention, the index for the mask class is the last class index
self.mask_class = self.num_embed - 1
at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
num_non_mask_classes = self.num_embed - 1
bt = (1 - at - ct) / num_non_mask_classes
btt = (1 - att - ctt) / num_non_mask_classes
at = torch.tensor(at.astype("float64"))
bt = torch.tensor(bt.astype("float64"))
ct = torch.tensor(ct.astype("float64"))
log_at = torch.log(at)
log_bt = torch.log(bt)
log_ct = torch.log(ct)
att = torch.tensor(att.astype("float64"))
btt = torch.tensor(btt.astype("float64"))
ctt = torch.tensor(ctt.astype("float64"))
log_cumprod_at = torch.log(att)
log_cumprod_bt = torch.log(btt)
log_cumprod_ct = torch.log(ctt)
self.log_at = log_at.float()
self.log_bt = log_bt.float()
self.log_ct = log_ct.float()
self.log_cumprod_at = log_cumprod_at.float()
self.log_cumprod_bt = log_cumprod_bt.float()
self.log_cumprod_ct = log_cumprod_ct.float()
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`):
device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps).to(device)
self.log_at = self.log_at.to(device)
self.log_bt = self.log_bt.to(device)
self.log_ct = self.log_ct.to(device)
self.log_cumprod_at = self.log_cumprod_at.to(device)
self.log_cumprod_bt = self.log_cumprod_bt.to(device)
self.log_cumprod_ct = self.log_cumprod_ct.to(device)
def step(
self,
model_output: torch.FloatTensor,
timestep: torch.long,
sample: torch.LongTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[VQDiffusionSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.
Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
t (`torch.long`):
The timestep that determines which transition matrices are used.
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`
generator: (`torch.Generator` or None):
RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.
return_dict (`bool`):
option for returning tuple rather than VQDiffusionSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if timestep == 0:
log_p_x_t_min_1 = model_output
else:
log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
if not return_dict:
return (x_t_min_1,)
return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
def q_posterior(self, log_p_x_0, x_t, t):
"""
Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
forward probabilities.
Equation (11) stated in terms of forward probabilities via Equation (5):
Where:
- the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0)
p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`
t (torch.Long):
The timestep that determines which transition matrix is used.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
"""
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
)
log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
)
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
# . . .
# . . .
# . . .
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
q = log_p_x_0 - log_q_x_t_given_x_0
# sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
# sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
# . . .
# . . .
# . . .
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
q = q - q_log_sum_exp
# (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
# . . .
# . . .
# . . .
# (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
# c_cumulative_{t-1} ... c_cumulative_{t-1}
q = self.apply_cumulative_transitions(q, t - 1)
# ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
# . . .
# . . .
# . . .
# ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
# c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0
log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
# For each column, there are two possible cases.
#
# Where:
# - sum(p_n(x_0))) is summing over all classes for x_0
# - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
# - C_j is the class transitioning to
#
# 1. x_t is masked i.e. x_t = c_k
#
# Simplifying the expression, the column vector is:
# .
# .
# .
# (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
# .
# .
# .
# (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
#
# From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
#
# For the other rows, we can state the equation as ...
#
# (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})]
#
# This verifies the other rows.
#
# 2. x_t is not masked
#
# Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
# .
# .
# .
# C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
# .
# .
# .
# C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
# .
# .
# .
# 0
#
# The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
return log_p_x_t_min_1
def log_Q_t_transitioning_to_known_class(
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
):
"""
Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
latent pixel in `x_t`.
See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
Args:
t (torch.Long):
The timestep that determines which transition matrix is used.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`.
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
The log one-hot vectors of `x_t`
cumulative (`bool`):
If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
we use the cumulative transition matrix `0`->`t`.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
transition matrix.
When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be
masked.
Where:
- `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
- C_0 is a class of a latent pixel embedding
- C_k is the class of the masked latent pixel
non-cumulative result (omitting logarithms):
```
q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
. . .
. . .
. . .
q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
```
cumulative result (omitting logarithms):
```
q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
. . .
. . .
. . .
q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
```
"""
if cumulative:
a = self.log_cumprod_at[t]
b = self.log_cumprod_bt[t]
c = self.log_cumprod_ct[t]
else:
a = self.log_at[t]
b = self.log_bt[t]
c = self.log_ct[t]
if not cumulative:
# The values in the onehot vector can also be used as the logprobs for transitioning
# from masked latent pixels. If we are not calculating the cumulative transitions,
# we need to save these vectors to be re-appended to the final matrix so the values
# aren't overwritten.
#
# `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector
# if x_t is not masked
#
# `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector
# if x_t is masked
log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
# `index_to_log_onehot` will add onehot vectors for masked pixels,
# so the default one hot matrix has one too many rows. See the doc string
# for an explanation of the dimensionality of the returned matrix.
log_onehot_x_t = log_onehot_x_t[:, :-1, :]
# this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
#
# Don't worry about what values this sets in the columns that mark transitions
# to masked latent pixels. They are overwrote later with the `mask_class_mask`.
#
# Looking at the below logspace formula in non-logspace, each value will evaluate to either
# `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
# or
# `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
#
# See equation 7 for more details.
log_Q_t = (log_onehot_x_t + a).logaddexp(b)
# The whole column of each masked pixel is `c`
mask_class_mask = x_t == self.mask_class
mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
log_Q_t[mask_class_mask] = c
if not cumulative:
log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
return log_Q_t
def apply_cumulative_transitions(self, q, t):
bsz = q.shape[0]
a = self.log_cumprod_at[t]
b = self.log_cumprod_bt[t]
c = self.log_cumprod_ct[t]
num_latent_pixels = q.shape[2]
c = c.expand(bsz, 1, num_latent_pixels)
q = (q + a).logaddexp(b)
q = torch.cat((q, c), dim=1)
return q

View File

@@ -31,7 +31,6 @@ from .import_utils import (
is_scipy_available,
is_tf_available,
is_torch_available,
is_torch_version,
is_transformers_available,
is_unidecode_available,
requires_backends,
@@ -43,7 +42,6 @@ from .outputs import BaseOutput
if is_torch_available():
from .testing_utils import (
floats_tensor,
load_hf_numpy,
load_image,
load_numpy,
parse_flag_from_env,

View File

@@ -94,21 +94,6 @@ class FlaxDDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["flax"])
class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]

View File

@@ -34,21 +34,6 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -242,21 +227,6 @@ class PNDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class RePaintPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -302,51 +272,6 @@ class DDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DPMSolverMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class IPNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -392,21 +317,6 @@ class PNDMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class RePaintScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]
@@ -437,21 +347,6 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -4,21 +4,6 @@
from ..utils import DummyObject, requires_backends
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -92,18 +77,3 @@ class StableDiffusionPipeline(metaclass=DummyObject):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

View File

@@ -15,14 +15,11 @@
Import utilities: Utilities related to imports and our lazy inits.
"""
import importlib.util
import operator as op
import os
import sys
from collections import OrderedDict
from typing import Union
from packaging import version
from packaging.version import Version, parse
from . import logging
@@ -43,8 +40,6 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
_torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
@@ -95,8 +90,7 @@ else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
@@ -142,7 +136,6 @@ except importlib_metadata.PackageNotFoundError:
_modelcards_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
@@ -173,18 +166,6 @@ try:
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
import torch
if torch.__version__ < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False
def is_torch_available():
return _torch_available
@@ -222,10 +203,6 @@ def is_scipy_available():
return _scipy_available
def is_xformers_available():
return _xformers_available
def is_accelerate_available():
return _accelerate_available
@@ -314,36 +291,3 @@ class DummyObject(type):
if key.startswith("_"):
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
Args:
Compares a library version to some requirement using a given operation.
library_or_version (`str` or `packaging.version.Version`):
A library name or a version to check.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`.
requirement_version (`str`):
The version to compare the library version against
"""
if operation not in STR_OPERATION_TO_FUNC.keys():
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
operation = STR_OPERATION_TO_FUNC[operation]
if isinstance(library_or_version, str):
library_or_version = parse(importlib_metadata.version(library_or_version))
return operation(library_or_version, parse(requirement_version))
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
"""
Args:
Compares the current PyTorch version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A string version of PyTorch
"""
return compare_versions(parse(_torch_version), operation, version)

View File

@@ -139,29 +139,6 @@ def require_onnxruntime(test_case):
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
if isinstance(arry, str):
if arry.startswith("http://") or arry.startswith("https://"):
response = requests.get(arry)
response.raise_for_status()
arry = np.load(BytesIO(response.content))
elif os.path.isfile(arry):
arry = np.load(arry)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
)
elif isinstance(arry, np.ndarray):
pass
else:
raise ValueError(
"Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
" ndarray."
)
return arry
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:
@@ -191,13 +168,17 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
return image
def load_hf_numpy(path) -> np.ndarray:
def load_numpy(path) -> np.ndarray:
if not path.startswith("http://") or path.startswith("https://"):
path = os.path.join(
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
)
return load_numpy(path)
response = requests.get(path)
response.raise_for_status()
array = np.load(BytesIO(response.content))
return array
# --- pytest conf functions --- #

View File

@@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
model.to(torch_device)
sample_size = 65536

View File

@@ -21,15 +21,7 @@ import unittest
import torch
from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.utils import (
floats_tensor,
load_hf_numpy,
logging,
require_torch_gpu,
slow,
torch_all_close,
torch_device,
)
from diffusers.utils import floats_tensor, load_numpy, logging, require_torch_gpu, slow, torch_all_close, torch_device
from parameterized import parameterized
from ..test_modeling_common import ModelTesterMixin
@@ -125,7 +117,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model.to(torch_device)
image = model(**self.dummy_input).sample
@@ -133,8 +127,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self):
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()
@@ -156,7 +151,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_normal_load.to(torch_device)
model_normal_load.eval()
@@ -170,8 +165,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
gc.collect()
tracemalloc.start()
# by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()
_, peak_accelerate = tracemalloc.get_traced_memory()
@@ -180,9 +176,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
torch.cuda.empty_cache()
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
)
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load.to(torch_device)
model_normal_load.eval()
_, peak_normal = tracemalloc.get_traced_memory()
@@ -346,7 +340,9 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
model, loading_info = UNet2DModel.from_pretrained(
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -360,7 +356,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@slow
def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
model.to(torch_device)
torch.manual_seed(0)
@@ -427,7 +423,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
@@ -435,7 +431,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32
model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
)
model.to(torch_device).eval()
@@ -443,7 +439,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
hidden_states = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return hidden_states
@parameterized.expand(
@@ -456,7 +452,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
@@ -508,7 +503,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
@@ -560,7 +554,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))

View File

@@ -20,7 +20,7 @@ import torch
from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from parameterized import parameterized
from ..test_modeling_common import ModelTesterMixin
@@ -147,7 +147,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
@@ -155,10 +155,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch_dtype = torch.float16 if fp16 else torch.float32
model = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch_dtype,
revision=revision,
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto"
)
model.to(torch_device).eval()

View File

@@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion(self):
device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
@@ -103,7 +103,9 @@ class PipelineIntegrationTests(unittest.TestCase):
def test_dance_diffusion_fp16(self):
device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
pipe = DanceDiffusionPipeline.from_pretrained(
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)

View File

@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_ema_bedroom(self):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
scheduler = DDIMScheduler()
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

View File

@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
scheduler = DDPMScheduler.from_config(model_id)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

View File

@@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class KarrasVePipelineIntegrationTests(unittest.TestCase):
def test_inference(self):
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
scheduler = KarrasVeScheduler()
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

View File

@@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@require_torch
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
def test_inference_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
@@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)

Some files were not shown because too many files have changed in this diff Show More