Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 21:14:44 +08:00)

Compare commits: fix-mps-cr...add_schedu, 1 commit (b444122ab0)

.github/workflows/build_docker_images.yml (vendored, 50 changed lines)
@@ -1,50 +0,0 @@
|
||||
name: Build Docker images (nightly)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # every day at midnight
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
REGISTRY: diffusers
|
||||
|
||||
jobs:
|
||||
build-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
image-name:
|
||||
- diffusers-pytorch-cpu
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-flax-cpu
|
||||
- diffusers-flax-tpu
|
||||
- diffusers-onnxruntime-cpu
|
||||
- diffusers-onnxruntime-cuda
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ env.REGISTRY }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
no-cache: true
|
||||
context: ./docker/${{ matrix.image-name }}
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest
|
||||
.github/workflows/pr_tests.yml (vendored, 75 changed lines)
@@ -10,46 +10,19 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 60
|
||||
MPS_TORCH_VERSION: 1.13.0
|
||||
|
||||
jobs:
|
||||
run_fast_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch CPU tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
run_tests_cpu:
|
||||
name: CPU tests on Ubuntu
|
||||
runs-on: [ self-hosted, docker-gpu ]
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
image: python:3.7
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -58,6 +31,8 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -65,43 +40,23 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
- name: Run all fast tests on CPU
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
run: cat reports/tests_torch_cpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
name: pr_torch_cpu_test_reports
|
||||
path: reports
|
||||
|
||||
run_fast_tests_apple_m1:
|
||||
name: Fast PyTorch MPS tests on MacOS
|
||||
run_tests_apple_m1:
|
||||
name: MPS tests on Apple M1
|
||||
runs-on: [ self-hosted, apple-m1 ]
|
||||
|
||||
steps:
|
||||
@@ -133,7 +88,7 @@ jobs:
|
||||
run: |
|
||||
${CONDA_RUN} python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch tests on M1 (MPS)
|
||||
- name: Run all fast tests on MPS
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
|
||||
|
||||
.github/workflows/push_tests.yml (vendored, 92 changed lines)
@@ -6,7 +6,6 @@ on:
|
||||
- main
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
@@ -14,38 +13,12 @@ env:
|
||||
RUN_SLOW: yes
|
||||
|
||||
jobs:
|
||||
run_slow_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Slow PyTorch CUDA tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
report: torch_cuda
|
||||
- name: Slow Flax TPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Slow ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
run_tests_single_gpu:
|
||||
name: Diffusers tests
|
||||
runs-on: [ self-hosted, docker-gpu, single-gpu ]
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
image: nvcr.io/nvidia/pytorch:22.07-py3
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -54,12 +27,14 @@ jobs:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if : ${{ matrix.config.runner == 'docker-gpu' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -67,55 +42,29 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow PyTorch CUDA tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
- name: Run all (incl. slow) tests on GPU
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
run: cat reports/tests_torch_gpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
name: torch_test_reports
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
|
||||
runs-on: docker-gpu
|
||||
|
||||
run_examples_single_gpu:
|
||||
name: Examples tests
|
||||
runs-on: [ self-hosted, docker-gpu, single-gpu ]
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
image: nvcr.io/nvidia/pytorch:22.07-py3
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -129,6 +78,9 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -140,11 +92,11 @@ jobs:
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/examples_torch_cuda_failures_short.txt
|
||||
run: cat reports/examples_torch_gpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
|
||||
README.md (22 changed lines)
@@ -27,12 +27,10 @@ More precisely, 🤗 Diffusers offers:
|
||||
|
||||
## Installation
|
||||
|
||||
### For PyTorch
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers[torch]
|
||||
pip install --upgrade diffusers
|
||||
```
|
||||
|
||||
**With `conda`**
|
||||
@@ -41,14 +39,6 @@ pip install --upgrade diffusers[torch]
|
||||
conda install -c conda-forge diffusers
|
||||
```
|
||||
|
||||
### For Flax
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers[flax]
|
||||
```
|
||||
|
||||
**Apple Silicon (M1/M2) support**
|
||||
|
||||
Please refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
|
||||
@@ -152,7 +142,11 @@ it before the pipeline and pass it to `from_pretrained`.
|
||||
```python
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -364,7 +358,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
|
||||
If you want to run the code yourself 💻, you can try out:
|
||||
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
```python
|
||||
# !pip install diffusers["torch"] transformers
|
||||
# !pip install diffusers transformers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
device = "cuda"
|
||||
@@ -383,7 +377,7 @@ image.save("squirrel.png")
|
||||
```
|
||||
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
||||
```python
|
||||
# !pip install diffusers["torch"]
|
||||
# !pip install diffusers
|
||||
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
||||
|
||||
model_id = "google/ddpm-celebahq-256"
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"jax[cpu]>=0.2.16,!=0.3.2" \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,44 +0,0 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
"jax[tpu]>=0.2.16,!=0.3.2" \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,42 +0,0 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,42 +0,0 @@
|
||||
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
"onnxruntime-gpu>=1.13.1" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,41 +0,0 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,41 +0,0 @@
|
||||
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -78,8 +78,6 @@
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: "Overview"
|
||||
- local: api/pipelines/cycle_diffusion
|
||||
title: "Cycle Diffusion"
|
||||
- local: api/pipelines/ddim
|
||||
title: "DDIM"
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -98,9 +96,5 @@
|
||||
title: "Stochastic Karras VE"
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: "Dance Diffusion"
|
||||
- local: api/pipelines/vq_diffusion
|
||||
title: "VQ Diffusion"
|
||||
- local: api/pipelines/repaint
|
||||
title: "RePaint"
|
||||
title: "Pipelines"
|
||||
title: "API"
|
||||
|
||||
@@ -49,12 +49,6 @@ The models are built on the base class [`ModelMixin`] that is a `torch.nn.Module`
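For illustration, a minimal sketch of how such a model is loaded and used as a regular PyTorch module (the checkpoint and subfolder below are assumptions, not part of this API page):

```python
import torch
from diffusers import UNet2DConditionModel

# any ModelMixin subclass is a plain torch.nn.Module and loads with `from_pretrained`
# (the repository id and subfolder are placeholders for illustration)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
assert isinstance(unet, torch.nn.Module)
```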
|
||||
## AutoencoderKL
|
||||
[[autodoc]] AutoencoderKL
|
||||
|
||||
## Transformer2DModel
|
||||
[[autodoc]] Transformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
[[autodoc]] models.attention.Transformer2DModelOutput
|
||||
|
||||
## FlaxModelMixin
|
||||
[[autodoc]] FlaxModelMixin
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Cycle Diffusion
|
||||
|
||||
## Overview
|
||||
|
||||
Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu, Fernando De la Torre.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.*
|
||||
|
||||
*Tips*:
|
||||
- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoint.
|
||||
- Currently Cycle Diffusion only works with the [`DDIMScheduler`].
|
||||
|
||||
*Example*:
|
||||
|
||||
In the following we show how to best use the [`CycleDiffusionPipeline`]:
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from diffusers import CycleDiffusionPipeline, DDIMScheduler
|
||||
|
||||
# load the pipeline
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
|
||||
|
||||
# let's download an initial image
|
||||
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((512, 512))
|
||||
init_image.save("horse.png")
|
||||
|
||||
# let's specify a prompt
|
||||
source_prompt = "An astronaut riding a horse"
|
||||
prompt = "An astronaut riding an elephant"
|
||||
|
||||
# call the pipeline
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
init_image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.8,
|
||||
guidance_scale=2,
|
||||
source_guidance_scale=1,
|
||||
).images[0]
|
||||
|
||||
image.save("horse_to_elephant.png")
|
||||
|
||||
# let's try another example
|
||||
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
|
||||
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((512, 512))
|
||||
init_image.save("black.png")
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
# call the pipeline
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
init_image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
).images[0]
|
||||
|
||||
image.save("black_to_blue.png")
|
||||
```
|
||||
|
||||
## CycleDiffusionPipeline
|
||||
[[autodoc]] CycleDiffusionPipeline
|
||||
- __call__
|
||||
@@ -28,7 +28,7 @@ or created independently from each other.
|
||||
|
||||
To that end, we strive to offer all open-sourced, state-of-the-art diffusion systems under a unified API.
|
||||
More specifically, we strive to provide pipelines that
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
|
||||
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
|
||||
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
|
||||
@@ -41,24 +41,19 @@ If you are looking for *official* training examples, please have a look at [exam
|
||||
The following table summarizes all officially supported pipelines, their corresponding paper, and if
|
||||
available a colab notebook to directly try them out.
|
||||
|
||||
|
||||
| Pipeline | Paper | Tasks | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# RePaint
|
||||
|
||||
## Overview
|
||||
|
||||
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
|
||||
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.
|
||||
|
||||
The original codebase can be found [here](https://github.com/andreas128/RePaint).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:|
|
||||
| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - |
|
||||
|
||||
## Usage example
|
||||
|
||||
```python
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
import requests
|
||||
from diffusers import RePaintPipeline, RePaintScheduler
|
||||
|
||||
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
|
||||
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
|
||||
|
||||
# Load the original image and the mask as PIL images
|
||||
original_image = download_image(img_url).resize((256, 256))
|
||||
mask_image = download_image(mask_url).resize((256, 256))
|
||||
|
||||
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
|
||||
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
|
||||
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
output = pipe(
|
||||
original_image=original_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=250,
|
||||
eta=0.0,
|
||||
jump_length=10,
|
||||
jump_n_sample=10,
|
||||
generator=generator,
|
||||
)
|
||||
inpainted_image = output.images[0]
|
||||
```
|
||||
|
||||
## RePaintPipeline
|
||||
[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline
|
||||
- __call__
|
||||
|
||||
@@ -31,21 +31,6 @@ For more details about how Stable Diffusion works and how it differs from the ba
|
||||
|
||||
## Tips
|
||||
|
||||
### How to load and use different schedulers.
|
||||
|
||||
The Stable Diffusion pipeline uses the [`PNDMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
|
||||
To use a different scheduler, you can pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
|
||||
```
|
||||
|
||||
|
||||
### How to cover all use cases with a single or multiple pipelines
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
|
||||
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
|
||||
- Make use of the `components` functionality to instantiate all components in the most memory-efficient way:
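A minimal sketch of the `components` approach (assuming, as in current `diffusers` releases, that the image-to-image pipeline accepts the same components as the text-to-image one):

```python
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

# load the full text-to-image pipeline once
text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# reuse the already-loaded weights for image-to-image instead of loading a second copy
img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
```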
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# VQDiffusion
|
||||
|
||||
## Overview
|
||||
|
||||
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.
|
||||
|
||||
The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|---|---|:---:|
|
||||
| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - |
|
||||
|
||||
|
||||
## VQDiffusionPipeline
|
||||
[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline
|
||||
- __call__
|
||||
@@ -70,12 +70,6 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
|
||||
|
||||
[[autodoc]] DDPMScheduler
|
||||
|
||||
#### Multistep DPM-Solver
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
[[autodoc]] DPMSolverMultistepScheduler
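As an illustrative sketch (the checkpoint and step count are assumptions, following the scheduler-swapping pattern used elsewhere in these docs), DPM-Solver is typically run with far fewer inference steps than DDIM or PNDM:

```python
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# load the scheduler configuration from the model repository
dpm = DPMSolverMultistepScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=dpm)

# around 20 steps is usually enough with DPM-Solver
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```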
|
||||
|
||||
#### Variance exploding, stochastic sampling from Karras et al.
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
|
||||
@@ -118,34 +112,3 @@ Score SDE-VP is under construction.
|
||||
</Tip>
|
||||
|
||||
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
|
||||
#### Euler scheduler
|
||||
|
||||
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
|
||||
A fast scheduler that often generates good outputs in 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerDiscreteScheduler
|
||||
|
||||
|
||||
#### Euler Ancestral scheduler
|
||||
|
||||
Ancestral sampling with Euler method steps. Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) implementation by Katherine Crowson.
|
||||
A fast scheduler that often generates good outputs in 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerAncestralDiscreteScheduler
|
||||
|
||||
|
||||
#### VQDiffusionScheduler
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
|
||||
|
||||
[[autodoc]] VQDiffusionScheduler
|
||||
|
||||
#### RePaint scheduler
|
||||
|
||||
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
|
||||
Intended for use with [`RePaintPipeline`].
|
||||
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
|
||||
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
|
||||
|
||||
[[autodoc]] RePaintScheduler
|
||||
|
||||
@@ -34,8 +34,6 @@ available a colab notebook to directly try them out.
|
||||
|
||||
| Pipeline | Paper | Tasks | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
@@ -47,6 +45,5 @@ available a colab notebook to directly try them out.
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
|
||||
@@ -12,12 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Installation
|
||||
|
||||
Install 🤗 Diffusers for whichever deep learning library you’re working with.
|
||||
Install 🤗 Diffusers with PyTorch. Support for other libraries will come in the future.
|
||||
|
||||
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
🤗 Diffusers is tested on Python 3.7+ and PyTorch 1.7.0+.
|
||||
|
||||
## Install with pip
|
||||
|
||||
@@ -39,30 +36,12 @@ source .env/bin/activate
|
||||
|
||||
Now you're ready to install 🤗 Diffusers with the following command:
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```bash
|
||||
pip install diffusers["torch"]
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"]
|
||||
pip install diffusers
|
||||
```
|
||||
|
||||
## Install from source
|
||||
|
||||
Before installing `diffusers` from source, make sure you have `torch` and `accelerate` installed.
|
||||
|
||||
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
|
||||
|
||||
To install `accelerate`
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
Install 🤗 Diffusers from source with the following command:
|
||||
|
||||
```bash
|
||||
@@ -88,18 +67,7 @@ Clone the repository and install 🤗 Diffusers with the following commands:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
```
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```
|
||||
pip install -e ".[torch]"
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```
|
||||
pip install -e ".[flax]"
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
These commands link the cloned repository folder to your Python library paths.
|
||||
|
||||
@@ -22,7 +22,6 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
|
||||
| fp16 | 3.61s | x2.63 |
|
||||
| channels last | 3.30s | x2.88 |
|
||||
| traced UNet | 3.21s | x2.96 |
|
||||
| memory efficient attention | 2.63s | x3.61 |
|
||||
|
||||
<em>
|
||||
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
|
||||
@@ -291,41 +290,3 @@ pipe.unet = TracedUNet()
|
||||
with torch.inference_mode():
|
||||
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||
|
||||
## Memory Efficient Attention
|
||||
Recent work on optimizing bandwidth in the attention block has generated huge speedups and reductions in GPU memory usage. The most recent is Flash Attention (from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)).
|
||||
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):
|
||||
|
||||
| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|
||||
|------------------ |--------------------- |--------------------------------- |
|
||||
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
|
||||
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
|
||||
| NVIDIA A10G | 8.88it/s | 15.6it/s |
|
||||
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
|
||||
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
|
||||
| A100-SXM4-40GB | 18.6it/s | 29.it/s |
|
||||
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
|
||||
|
||||
To leverage it just make sure you have:
|
||||
- PyTorch > 1.12
|
||||
- CUDA available
|
||||
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
with torch.inference_mode():
|
||||
sample = pipe("a small cat")
|
||||
|
||||
# optional: You can disable it via
|
||||
# pipe.disable_xformers_memory_efficient_attention()
|
||||
```
|
||||
@@ -19,8 +19,11 @@ specific language governing permissions and limitations under the License.
|
||||
- Mac computer with Apple silicon (M1/M2) hardware.
|
||||
- macOS 12.6 or later (13.0 or later recommended).
|
||||
- arm64 version of Python.
|
||||
- PyTorch 1.13. You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/.
|
||||
- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
|
||||
|
||||
```
|
||||
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
|
||||
```
|
||||
|
||||
## Inference Pipeline
|
||||
|
||||
@@ -60,4 +63,4 @@ pipeline.enable_attention_slicing()
|
||||
## Known Issues
|
||||
|
||||
- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
|
||||
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). This is being resolved, but for now we recommend to iterate instead of batching.
|
||||
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating over the prompts instead of batching (see the sketch below).
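A minimal sketch of the per-prompt loop (the model id and prompts are placeholders):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("mps")
pipe.enable_attention_slicing()

prompts = ["a photo of an astronaut riding a horse", "a photo of a corgi wearing sunglasses"]

# generate one image per call instead of passing the whole list as a batch
images = [pipe(prompt).images[0] for prompt in prompts]
```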
|
||||
|
||||
@@ -121,7 +121,7 @@ you could use it as follows:
|
||||
```python
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
|
||||
>>> generator = StableDiffusionPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
|
||||
|
||||
@@ -38,9 +38,9 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
|
||||
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./text_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ If a community pipeline doesn't work as expected, please open an issue and ping the author
|
||||
| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
|
||||
|
||||
To load a custom pipeline, you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, set to one of the files in `diffusers/examples/community`, as sketched below. Feel free to send a PR with your own pipelines; we will merge them quickly.
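A minimal sketch (the checkpoint is an assumption; the Stable Diffusion Mega community pipeline is used here as an example):

```python
from diffusers import DiffusionPipeline

# `custom_pipeline` points at one of the files in diffusers/examples/community
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="stable_diffusion_mega",
)
```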
|
||||
@@ -374,49 +373,6 @@ for i in range(4):
|
||||
for i, img in enumerate(images):
|
||||
img.save(f"./composable_diffusion/image_{i}.png")
|
||||
```
|
||||
|
||||
### Imagic Stable Diffusion
Allows you to edit an image using Stable Diffusion.

```python
import os
import requests
from io import BytesIO

import torch
from PIL import Image
from diffusers import DiffusionPipeline, DDIMScheduler

has_cuda = torch.cuda.is_available()
device = torch.device("cpu" if not has_cuda else "cuda")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    use_auth_token=True,
    custom_pipeline="imagic_stable_diffusion",
    scheduler=DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
    ),
).to(device)

generator = torch.Generator(device).manual_seed(0)
prompt = "A photo of Barack Obama smiling with a big grin"

url = "https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))

# First optimize the text embedding and fine-tune the model to reconstruct the input image.
res = pipe.train(
    prompt,
    init_image,
    guidance_scale=7.5,
    num_inference_steps=50,
    generator=generator,
)

# Then generate edits by interpolating/extrapolating the embeddings with different alpha values.
os.makedirs("./imagic", exist_ok=True)

res = pipe(alpha=1)
image = res.images[0]
image.save("./imagic/imagic_image_alpha_1.png")

res = pipe(alpha=1.5)
image = res.images[0]
image.save("./imagic/imagic_image_alpha_1_5.png")

res = pipe(alpha=2)
image = res.images[0]
image.save("./imagic/imagic_image_alpha_2.png")
```
|
||||
|
||||
### Seed Resizing
|
||||
Test seed resizing. First generate an image at 512x512, then generate an image with the same seed at 512x592 using seed resizing. Finally, generate a 512x592 image with the original Stable Diffusion pipeline for comparison (a minimal loading sketch follows below).
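A minimal loading sketch (the community file name and call arguments below are assumptions based on the pipelines listed above; the tail of the full example appears in the diff that follows):

```python
import torch
from diffusers import DiffusionPipeline

# Assumed community file name for the seed-resizing pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="seed_resize_stable_diffusion",
    use_auth_token=True,
).to("cuda")

prompt = "A painting of a squirrel eating a burger"

# Same seed at two resolutions: 512x512 first, then 512x592 with seed resizing.
generator = torch.Generator("cuda").manual_seed(0)
image_512 = pipe(prompt=prompt, height=512, width=512, generator=generator).images[0]

generator = torch.Generator("cuda").manual_seed(0)
image_592 = pipe(prompt=prompt, height=512, width=592, generator=generator).images[0]
```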
|
||||
|
||||
@@ -500,4 +456,4 @@ res = pipe_compare(
|
||||
|
||||
image = res.images[0]
|
||||
image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
|
||||
```
|
||||
```
|
||||
@@ -1,476 +0,0 @@
|
||||
"""
|
||||
modeled after the textual_inversion.py / train_dreambooth.py and the work
|
||||
of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
|
||||
"""
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import logging
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
class ImagicStableDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for imagic image editing.
|
||||
See paper here: https://arxiv.org/pdf/2210.09276.pdf
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
            Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def train(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
embedding_learning_rate: float = 0.001,
|
||||
diffusion_model_learning_rate: float = 2e-6,
|
||||
text_embedding_optimization_steps: int = 500,
|
||||
model_fine_tuning_optimization_steps: int = 1000,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
                tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
                The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=1,
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
# Freeze vae and unet
|
||||
self.vae.requires_grad_(False)
|
||||
self.unet.requires_grad_(False)
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.unet.eval()
|
||||
self.vae.eval()
|
||||
self.text_encoder.eval()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(
|
||||
"imagic",
|
||||
config={
|
||||
"embedding_learning_rate": embedding_learning_rate,
|
||||
"text_embedding_optimization_steps": text_embedding_optimization_steps,
|
||||
},
|
||||
)
|
||||
|
||||
# get text embeddings for prompt
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
            truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = torch.nn.Parameter(
|
||||
self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
|
||||
)
|
||||
text_embeddings = text_embeddings.detach()
|
||||
text_embeddings.requires_grad_()
|
||||
text_embeddings_orig = text_embeddings.clone()
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.Adam(
|
||||
[text_embeddings], # only optimize the embeddings
|
||||
lr=embedding_learning_rate,
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_image_dist = self.vae.encode(init_image).latent_dist
|
||||
init_image_latents = init_latent_image_dist.sample(generator=generator)
|
||||
init_image_latents = 0.18215 * init_image_latents
|
||||
|
||||
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
global_step = 0
|
||||
|
||||
logger.info("First optimizing the text embedding to better reconstruct the init image")
|
||||
for _ in range(text_embedding_optimization_steps):
|
||||
with accelerator.accumulate(text_embeddings):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
|
||||
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_embeddings.requires_grad_(False)
|
||||
|
||||
# Now we fine tune the unet to better reconstruct the image
|
||||
self.unet.requires_grad_(True)
|
||||
self.unet.train()
|
||||
optimizer = torch.optim.Adam(
|
||||
self.unet.parameters(), # only optimize unet
|
||||
lr=diffusion_model_learning_rate,
|
||||
)
|
||||
progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
|
||||
|
||||
logger.info("Next fine tuning the entire model to better reconstruct the init image")
|
||||
for _ in range(model_fine_tuning_optimization_steps):
|
||||
with accelerator.accumulate(self.unet.parameters()):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
|
||||
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
self.text_embeddings_orig = text_embeddings_orig
|
||||
self.text_embeddings = text_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
alpha: float = 1.2,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
guidance_scale: float = 7.5,
|
||||
eta: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
                tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
                The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if self.text_embeddings is None:
|
||||
raise ValueError("Please run the pipe.train() before trying to generate an image.")
|
||||
if self.text_embeddings_orig is None:
|
||||
raise ValueError("Please run the pipe.train() before trying to generate an image.")
|
||||
|
||||
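        # Linearly interpolate between the optimized embedding (which reconstructs the input image)
        # and the original target-text embedding; alpha > 1 extrapolates beyond the target text to strengthen the edit.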
text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""]
|
||||
max_length = self.tokenizer.model_max_length
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
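        # Rescale by the inverse of the Stable Diffusion VAE latent scaling factor (0.18215) before decoding.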
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -307,7 +307,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import deprecate, is_accelerate_available, logging
|
||||
from diffusers.utils import deprecate, logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
@@ -340,15 +340,13 @@ def get_weighted_text_embeddings(
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||
if (not skip_parsing) and (not skip_weighting):
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
return text_embeddings, uncond_embeddings
|
||||
@@ -433,19 +431,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -466,24 +451,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -511,23 +478,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = self.device
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -548,7 +498,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -611,15 +560,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
`None` if cancelled by `is_cancelled_callback`,
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
@@ -812,11 +757,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -435,7 +435,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -497,15 +496,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
`None` if cancelled by `is_cancelled_callback`,
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
@@ -673,11 +668,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
@@ -701,7 +693,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
has_nsfw_concept.append(has_nsfw_concept_i)
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
@@ -148,7 +148,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -177,7 +177,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -295,7 +295,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -324,7 +324,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -185,7 +185,7 @@ accelerate launch train_dreambooth.py \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--use_8bit_adam \
|
||||
--use_8bit_adam
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=2e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
@@ -291,4 +291,4 @@ python train_dreambooth_flax.py \
|
||||
--learning_rate=2e-6 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
```
|
||||
@@ -469,7 +469,9 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
|
||||
)
|
||||
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
@@ -494,12 +496,7 @@ def main(args):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
|
||||
@@ -361,8 +361,7 @@ def main():
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
|
||||
|
||||
@@ -372,7 +372,11 @@ def main():
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
|
||||
# TODO (patil-suraj): load scheduler using args
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
|
||||
)
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
@@ -605,7 +609,9 @@ def main():
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ accelerate config
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
@@ -111,4 +111,4 @@ python textual_inversion_flax.py \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
It should be at least 70% faster than the PyTorch script with the same configuration.
|
||||
It should be at least 70% faster than the PyTorch script with the same configuration.
|
||||
@@ -419,7 +419,13 @@ def main():
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
# TODO (patil-suraj): load scheduler using args
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.train_data_dir,
|
||||
@@ -552,7 +558,9 @@ def main():
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
||||
@@ -29,24 +29,6 @@ from tqdm.auto import tqdm
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
if not isinstance(arr, torch.Tensor):
|
||||
arr = torch.from_numpy(arr)
|
||||
res = arr[timesteps].float().to(timesteps.device)
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -189,16 +171,6 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--predict_mode",
|
||||
type=str,
|
||||
default="eps",
|
||||
help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
|
||||
)
|
||||
|
||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
@@ -252,7 +224,7 @@ def main(args):
|
||||
"UpBlock2D",
|
||||
),
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
@@ -285,8 +257,6 @@ def main(args):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
@@ -349,20 +319,8 @@ def main(args):
|
||||
|
||||
with accelerator.accumulate(model):
|
||||
# Predict the noise residual
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.predict_mode == "eps":
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
elif args.predict_mode == "x0":
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
snr_weights = alpha_t / (1 - alpha_t)
|
||||
loss = snr_weights * F.mse_loss(
|
||||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
loss = loss.mean()
|
||||
|
||||
noise_pred = model(noisy_images, timesteps).sample
|
||||
loss = F.mse_loss(noise_pred, noise)
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
@@ -397,12 +355,7 @@ def main(args):
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
images = pipeline(
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
output_type="numpy",
|
||||
predict_epsilon=args.predict_mode == "eps",
|
||||
).images
|
||||
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
|
||||
@@ -1,885 +0,0 @@
|
||||
"""
|
||||
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
|
||||
|
||||
It currently only supports porting the ITHQ dataset.
|
||||
|
||||
ITHQ dataset:
|
||||
```sh
|
||||
# From the root directory of diffusers.
|
||||
|
||||
# Download the VQVAE checkpoint
|
||||
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth
|
||||
|
||||
# Download the VQVAE config
|
||||
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
|
||||
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
|
||||
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
|
||||
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
|
||||
|
||||
# Download the main model checkpoint
|
||||
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth
|
||||
|
||||
# Download the main model config
|
||||
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml
|
||||
|
||||
# run the convert script
|
||||
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
|
||||
--checkpoint_path ./ithq_learnable.pth \
|
||||
--original_config_file ./ithq.yaml \
|
||||
--vqvae_checkpoint_path ./ithq_vqvae.pth \
|
||||
--vqvae_original_config_file ./ithq_vqvae.yaml \
|
||||
--dump_path <path to save pre-trained `VQDiffusionPipeline`>
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
import yaml
|
||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
||||
from diffusers.models.attention import Transformer2DModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from yaml.loader import FullLoader
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
|
||||
" OmegaConf`."
|
||||
)
|
||||
|
||||
# vqvae model
|
||||
|
||||
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config.params
|
||||
|
||||
original_encoder_config = original_config.encoder_config.params
|
||||
original_decoder_config = original_config.decoder_config.params
|
||||
|
||||
in_channels = original_encoder_config.in_channels
|
||||
out_channels = original_decoder_config.out_ch
|
||||
|
||||
down_block_types = get_down_block_types(original_encoder_config)
|
||||
up_block_types = get_up_block_types(original_decoder_config)
|
||||
|
||||
assert original_encoder_config.ch == original_decoder_config.ch
|
||||
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
|
||||
block_out_channels = tuple(
|
||||
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
|
||||
)
|
||||
|
||||
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
|
||||
layers_per_block = original_encoder_config.num_res_blocks
|
||||
|
||||
assert original_encoder_config.z_channels == original_decoder_config.z_channels
|
||||
latent_channels = original_encoder_config.z_channels
|
||||
|
||||
num_vq_embeddings = original_config.n_embed
|
||||
|
||||
    # Hard-coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
|
||||
norm_num_groups = 32
|
||||
|
||||
e_dim = original_config.embed_dim
|
||||
|
||||
model = VQModel(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
down_block_types=down_block_types,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
latent_channels=latent_channels,
|
||||
num_vq_embeddings=num_vq_embeddings,
|
||||
norm_num_groups=norm_num_groups,
|
||||
vq_embed_dim=e_dim,
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_down_block_types(original_encoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_encoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_encoder_config.resolution)
|
||||
|
||||
curr_res = resolution
|
||||
down_block_types = []
|
||||
|
||||
for _ in range(num_resolutions):
|
||||
if curr_res in attn_resolutions:
|
||||
down_block_type = "AttnDownEncoderBlock2D"
|
||||
else:
|
||||
down_block_type = "DownEncoderBlock2D"
|
||||
|
||||
down_block_types.append(down_block_type)
|
||||
|
||||
curr_res = [r // 2 for r in curr_res]
|
||||
|
||||
return down_block_types
|
||||
|
||||
|
||||
def get_up_block_types(original_decoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_decoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_decoder_config.resolution)
|
||||
|
||||
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
|
||||
up_block_types = []
|
||||
|
||||
for _ in reversed(range(num_resolutions)):
|
||||
if curr_res in attn_resolutions:
|
||||
up_block_type = "AttnUpDecoderBlock2D"
|
||||
else:
|
||||
up_block_type = "UpDecoderBlock2D"
|
||||
|
||||
up_block_types.append(up_block_type)
|
||||
|
||||
curr_res = [r * 2 for r in curr_res]
|
||||
|
||||
return up_block_types
|
||||
|
||||
|
||||
def coerce_attn_resolutions(attn_resolutions):
|
||||
attn_resolutions = OmegaConf.to_object(attn_resolutions)
|
||||
attn_resolutions_ = []
|
||||
for ar in attn_resolutions:
|
||||
if isinstance(ar, (list, tuple)):
|
||||
attn_resolutions_.append(list(ar))
|
||||
else:
|
||||
attn_resolutions_.append([ar, ar])
|
||||
return attn_resolutions_
|
||||
|
||||
|
||||
def coerce_resolution(resolution):
|
||||
resolution = OmegaConf.to_object(resolution)
|
||||
if isinstance(resolution, int):
|
||||
resolution = [resolution, resolution] # H, W
|
||||
elif isinstance(resolution, (tuple, list)):
|
||||
resolution = list(resolution)
|
||||
else:
|
||||
raise ValueError("Unknown type of resolution:", resolution)
|
||||
return resolution
|
||||
|
||||
|
||||
# done vqvae model
|
||||
|
||||
# vqvae checkpoint
|
||||
|
||||
|
||||
def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
# quant_conv
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"quant_conv.weight": checkpoint["quant_conv.weight"],
|
||||
"quant_conv.bias": checkpoint["quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# quantize
|
||||
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
|
||||
|
||||
# post_quant_conv
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
|
||||
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# decoder
|
||||
diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
|
||||
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# down_blocks
|
||||
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
|
||||
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
|
||||
down_block_prefix = f"encoder.down.{down_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(down_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# downsample
|
||||
|
||||
# do not include the downsample when on the last down block
|
||||
# There is no downsample on the last down block
|
||||
if down_block_idx != len(model.encoder.down_blocks) - 1:
|
||||
# There's a single downsample in the original checkpoint but a list of downsamples
|
||||
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
|
||||
downsample_prefix = f"{down_block_prefix}.downsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(down_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(down_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
|
||||
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
|
||||
attention_prefix = "encoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
|
||||
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
|
||||
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
|
||||
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
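# For illustration (the indices are hypothetical; the real ones depend on the config),
# the mapping above produces renames such as
#     "encoder.down.0.block.1.conv1.weight"   -> "encoder.down_blocks.0.resnets.1.conv1.weight"
#     "encoder.down.0.downsample.conv.weight" -> "encoder.down_blocks.0.downsamplers.0.conv.weight"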
||||
|
||||
def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
|
||||
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# up_blocks
|
||||
|
||||
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
|
||||
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
|
||||
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
|
||||
|
||||
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
|
||||
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(up_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# upsample
|
||||
|
||||
# there is no up sample on the last up block
|
||||
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
|
||||
# There's a single upsample in the VQ-diffusion checkpoint but a list of upsamples
|
||||
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
|
||||
downsample_prefix = f"{up_block_prefix}.upsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(up_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(up_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
|
||||
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
|
||||
attention_prefix = "decoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
|
||||
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
|
||||
"decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
|
||||
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
||||
rv = {
|
||||
# norm1
|
||||
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
|
||||
# conv1
|
||||
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
|
||||
# norm2
|
||||
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
|
||||
# conv2
|
||||
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
|
||||
}
|
||||
|
||||
if resnet.conv_shortcut is not None:
|
||||
rv.update(
|
||||
{
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# group_norm
|
||||
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
||||
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
|
||||
# proj_attn
|
||||
f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
|
||||
:, :, 0, 0
|
||||
],
|
||||
f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
||||
}
|
||||
|
||||
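# Why the `[:, :, 0, 0]` slicing above works (illustrative check, not part of the
# original script): the original q/k/v/proj_out layers are 1x1 convolutions, so
# dropping the two trailing spatial dimensions of an (out_channels, in_channels, 1, 1)
# weight gives the (out_channels, in_channels) matrix expected by diffusers' linear layers.
_conv_weight = torch.randn(4, 4, 1, 1)
assert torch.equal(_conv_weight[:, :, 0, 0], _conv_weight.reshape(4, 4))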
|
||||
# done vqvae checkpoint
|
||||
|
||||
# transformer model
|
||||
|
||||
PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
|
||||
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
|
||||
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]
|
||||
|
||||
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config.target in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config.target in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config.params
|
||||
original_transformer_config = original_transformer_config.params
|
||||
original_content_embedding_config = original_content_embedding_config.params
|
||||
|
||||
inner_dim = original_transformer_config["n_embd"]
|
||||
|
||||
n_heads = original_transformer_config["n_head"]
|
||||
|
||||
# VQ-Diffusion gives the width of the multi-headed attention layers as the
# number of attention heads times the feature dimension of a single head.
# We want to specify our attention blocks with those two values separately.
|
||||
assert inner_dim % n_heads == 0
|
||||
d_head = inner_dim // n_heads
|
||||
|
||||
depth = original_transformer_config["n_layer"]
|
||||
context_dim = original_transformer_config["condition_dim"]
|
||||
|
||||
num_embed = original_content_embedding_config["num_embed"]
|
||||
# the number of embeddings in the transformer includes the mask embedding.
|
||||
# the content embedding (the vqvae) does not include the mask embedding.
|
||||
num_embed = num_embed + 1
|
||||
|
||||
height = original_transformer_config["content_spatial_size"][0]
|
||||
width = original_transformer_config["content_spatial_size"][1]
|
||||
|
||||
assert width == height, "width has to be equal to height"
|
||||
dropout = original_transformer_config["resid_pdrop"]
|
||||
num_embeds_ada_norm = original_diffusion_config["diffusion_step"]
|
||||
|
||||
model_kwargs = {
|
||||
"attention_bias": True,
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": d_head,
|
||||
"num_layers": depth,
|
||||
"dropout": dropout,
|
||||
"num_attention_heads": n_heads,
|
||||
"num_vector_embeds": num_embed,
|
||||
"num_embeds_ada_norm": num_embeds_ada_norm,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": width,
|
||||
"activation_fn": "geglu-approximate",
|
||||
}
|
||||
|
||||
model = Transformer2DModel(**model_kwargs)
|
||||
return model
|
||||
|
||||
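# Worked example with hypothetical config values (not taken from a real config):
# n_embd = 1024 and n_head = 16 give attention_head_dim = 1024 // 16 = 64, and the
# diffusers model recovers the same inner dimension as 16 * 64 = 1024.
_example_inner_dim, _example_n_heads = 1024, 16
assert _example_inner_dim % _example_n_heads == 0
assert _example_inner_dim // _example_n_heads == 64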
|
||||
# done transformer model
|
||||
|
||||
# transformer checkpoint
|
||||
|
||||
|
||||
def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
transformer_prefix = "transformer.transformer"
|
||||
|
||||
diffusers_latent_image_embedding_prefix = "latent_image_embedding"
|
||||
latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"
|
||||
|
||||
# DalleMaskImageEmbedding
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.height_emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.width_emb.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# transformer blocks
|
||||
for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
|
||||
diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
|
||||
transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# norm block
|
||||
diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
|
||||
norm_block_prefix = f"{transformer_block_prefix}.ln2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
|
||||
f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# feedforward block
|
||||
diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
|
||||
feedforward_prefix = f"{transformer_block_prefix}.mlp"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_feedforward_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_feedforward_prefix=diffusers_feedforward_prefix,
|
||||
feedforward_prefix=feedforward_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# to logits
|
||||
|
||||
diffusers_norm_out_prefix = "norm_out"
|
||||
norm_out_prefix = f"{transformer_prefix}.to_logits.0"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
|
||||
f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_out_prefix = "out"
|
||||
out_prefix = f"{transformer_prefix}.to_logits.1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
|
||||
f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
|
||||
return {
|
||||
f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
|
||||
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
|
||||
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
|
||||
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
|
||||
# linear out
|
||||
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
|
||||
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
|
||||
return {
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
|
||||
}
|
||||
|
||||
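# Optional sanity check (an illustrative sketch, not part of the original script):
# the remapped keys should exactly match the diffusers transformer's state dict keys,
# e.g. for a model/checkpoint pair passed to the function above:
#
#     expected = set(model.state_dict().keys())
#     produced = set(transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint).keys())
#     assert expected == produced, (expected - produced, produced - expected)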
|
||||
# done transformer checkpoint
|
||||
|
||||
|
||||
def read_config_file(filename):
    # The YAML file contains annotations indicating that certain values should be
    # loaded as tuples. By default, OmegaConf raises an error when reading these.
    # Instead, we manually read the YAML with the FullLoader and then construct
    # the OmegaConf object from the result.
    with open(filename) as f:
        original_config = yaml.load(f, FullLoader)

    return OmegaConf.create(original_config)
|
||||
|
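# Illustrative sketch (hypothetical YAML snippet, not one of the real config files):
# the original configs annotate some values with tags like `!!python/tuple`, which
# `yaml.FullLoader` resolves to Python tuples before OmegaConf wraps the result.
#
#     yaml.load("resolution: !!python/tuple [256, 256]", FullLoader)
#     # -> {"resolution": (256, 256)}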
||||
|
||||
# We take separate arguments for the vqvae because the ITHQ vqvae config file
|
||||
# is separate from the config file for the rest of the model.
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the vqvae checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture for the vqvae.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_load_device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading checkpoints.",
|
||||
)
|
||||
|
||||
# See link for how ema weights are always selected
|
||||
# https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
|
||||
parser.add_argument(
|
||||
"--no_use_ema",
|
||||
action="store_true",
|
||||
required=False,
|
||||
help=(
|
||||
"Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set"
|
||||
" it as the original VQ-Diffusion always uses the ema weights when loading models."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
use_ema = not args.no_use_ema
|
||||
|
||||
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
||||
|
||||
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
||||
|
||||
# vqvae_model
|
||||
|
||||
print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")
|
||||
|
||||
vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
|
||||
vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]
|
||||
|
||||
with init_empty_weights():
|
||||
vqvae_model = vqvae_model_from_original_config(vqvae_original_config)
|
||||
|
||||
vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
|
||||
torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
|
||||
del vqvae_diffusers_checkpoint
|
||||
del vqvae_checkpoint
|
||||
load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading vqvae")
|
||||
|
||||
# done vqvae_model
|
||||
|
||||
# transformer_model
|
||||
|
||||
print(
|
||||
f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
|
||||
f" {use_ema}"
|
||||
)
|
||||
|
||||
original_config = read_config_file(args.original_config_file).model
|
||||
|
||||
diffusion_config = original_config.params.diffusion_config
|
||||
transformer_config = original_config.params.diffusion_config.params.transformer_config
|
||||
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
|
||||
|
||||
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
if use_ema:
|
||||
if "ema" in pre_checkpoint:
|
||||
checkpoint = {}
|
||||
for k, v in pre_checkpoint["model"].items():
|
||||
checkpoint[k] = v
|
||||
|
||||
for k, v in pre_checkpoint["ema"].items():
|
||||
# The ema weights are only used for the transformer. To make their keys look as if they came
|
||||
# from the state_dict for the top level model, we prefix with an additional "transformer."
|
||||
# See the source linked in the args.use_ema config for more information.
|
||||
checkpoint[f"transformer.{k}"] = v
|
||||
else:
|
||||
print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
else:
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
|
||||
del pre_checkpoint
|
||||
|
||||
with init_empty_weights():
|
||||
transformer_model = transformer_model_from_original_config(
|
||||
diffusion_config, transformer_config, content_embedding_config
|
||||
)
|
||||
|
||||
diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
|
||||
transformer_model, checkpoint
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
|
||||
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
|
||||
del diffusers_transformer_checkpoint
|
||||
del checkpoint
|
||||
load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading transformer")
|
||||
|
||||
# done transformer_model
|
||||
|
||||
# text encoder
|
||||
|
||||
print("loading CLIP text encoder")
|
||||
|
||||
clip_name = "openai/clip-vit-base-patch32"
|
||||
|
||||
# The original VQ-Diffusion specifies the pad value by the int used in the
|
||||
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
|
||||
# specifies the pad value via the token before it has been tokenized. The `!` pad
|
||||
# token is the same as padding with the `0` pad value.
|
||||
pad_token = "!"
|
||||
|
||||
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
|
||||
|
||||
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
|
||||
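# Illustrative check (an assumption, not part of the original script): with the
# `!` pad token configured above, padding a short prompt fills the remaining
# positions with id 0, e.g.
#
#     ids = tokenizer_model("a cat", padding="max_length", max_length=8).input_ids
#     assert ids[-1] == 0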
|
||||
text_encoder_model = CLIPTextModel.from_pretrained(
|
||||
clip_name,
|
||||
# `CLIPTextModel` does not support device_map="auto"
|
||||
# device_map="auto"
|
||||
)
|
||||
|
||||
print("done loading CLIP text encoder")
|
||||
|
||||
# done text encoder
|
||||
|
||||
# scheduler
|
||||
|
||||
scheduler_model = VQDiffusionScheduler(
|
||||
# the scheduler has the same number of embeddings as the transformer
|
||||
num_vec_classes=transformer_model.num_vector_embeds
|
||||
)
|
||||
|
||||
# done scheduler
|
||||
|
||||
print(f"saving VQ diffusion model, path: {args.dump_path}")
|
||||
|
||||
pipe = VQDiffusionPipeline(
|
||||
vqvae=vqvae_model,
|
||||
transformer=transformer_model,
|
||||
tokenizer=tokenizer_model,
|
||||
text_encoder=text_encoder_model,
|
||||
scheduler=scheduler_model,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
|
||||
print("done writing VQ diffusion model")
|
||||
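# Usage sketch (not part of the conversion script; the prompt below is arbitrary):
# the saved pipeline can be reloaded and run with the standard diffusers API.
#
#     pipe = VQDiffusionPipeline.from_pretrained(args.dump_path)
#     image = pipe("teddy bear playing in the pool", num_inference_steps=100).images[0]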
setup.py
@@ -89,10 +89,11 @@ _deps = [
|
||||
"huggingface-hub>=0.10.0",
|
||||
"importlib_metadata",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib>=0.1.65,<=0.3.6",
|
||||
"modelcards>=0.1.4",
|
||||
"numpy",
|
||||
"onnxruntime",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
@@ -178,7 +179,9 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||
extras["test"] = deps_list(
|
||||
"accelerate",
|
||||
"datasets",
|
||||
"onnxruntime",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
@@ -187,7 +190,7 @@ extras["test"] = deps_list(
|
||||
"torchvision",
|
||||
"transformers"
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
extras["torch"] = deps_list("torch")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
@@ -210,7 +213,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -9,7 +9,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = "0.8.0.dev0"
|
||||
__version__ = "0.7.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
@@ -18,7 +18,7 @@ from .utils import logging
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -36,22 +36,16 @@ if is_torch_available():
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
else:
|
||||
@@ -64,13 +58,11 @@ else:
|
||||
|
||||
if is_torch_available() and is_transformers_available():
|
||||
from .pipelines import (
|
||||
CycleDiffusionPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
@@ -93,7 +85,6 @@ if is_flax_available():
|
||||
from .schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
""" ConfigMixin base class and utilities."""
|
||||
import dataclasses
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
@@ -49,13 +48,9 @@ class ConfigMixin:
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by parent class).
|
||||
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
||||
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
||||
by parent class).
|
||||
"""
|
||||
config_name = None
|
||||
ignore_for_config = []
|
||||
_compatible_classes = []
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
@@ -285,14 +280,9 @@ class ConfigMixin:
|
||||
|
||||
return config_dict
|
||||
|
||||
@staticmethod
|
||||
def _get_init_keys(cls):
|
||||
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
|
||||
@classmethod
|
||||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
# 1. Retrieve expected config attributes from __init__ signature
|
||||
expected_keys = cls._get_init_keys(cls)
|
||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
expected_keys.remove("self")
|
||||
# remove general kwargs if present in dict
|
||||
if "kwargs" in expected_keys:
|
||||
@@ -302,36 +292,9 @@ class ConfigMixin:
|
||||
for arg in cls._flax_internal_args:
|
||||
expected_keys.remove(arg)
|
||||
|
||||
# 2. Remove attributes that cannot be expected from expected config attributes
|
||||
# remove keys to be ignored
|
||||
if len(cls.ignore_for_config) > 0:
|
||||
expected_keys = expected_keys - set(cls.ignore_for_config)
|
||||
|
||||
# load diffusers library to import compatible and original scheduler
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
|
||||
# remove attributes from compatible classes that orig cannot expect
|
||||
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
|
||||
# filter out None values, which come from potentially undefined dummy classes
|
||||
compatible_classes = [c for c in compatible_classes if c is not None]
|
||||
expected_keys_comp_cls = set()
|
||||
for c in compatible_classes:
|
||||
expected_keys_c = cls._get_init_keys(c)
|
||||
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
||||
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
||||
|
||||
# remove attributes from orig class that cannot be expected
|
||||
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
||||
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
||||
orig_cls = getattr(diffusers_library, orig_cls_name)
|
||||
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
||||
|
||||
# remove private attributes
|
||||
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
||||
|
||||
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
||||
init_dict = {}
|
||||
for key in expected_keys:
|
||||
if key in kwargs:
|
||||
@@ -341,7 +304,8 @@ class ConfigMixin:
|
||||
# use value from config dict
|
||||
init_dict[key] = config_dict.pop(key)
|
||||
|
||||
# 4. Give nice warning if unexpected values have been passed
|
||||
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
||||
|
||||
if len(config_dict) > 0:
|
||||
logger.warning(
|
||||
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
||||
@@ -349,16 +313,14 @@ class ConfigMixin:
|
||||
f"{cls.config_name} configuration file."
|
||||
)
|
||||
|
||||
# 5. Give nice info if config attributes are initialized to default because they have not been passed
|
||||
unused_kwargs = {**config_dict, **kwargs}
|
||||
|
||||
passed_keys = set(init_dict.keys())
|
||||
if len(expected_keys - passed_keys) > 0:
|
||||
logger.info(
|
||||
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
||||
)
|
||||
|
||||
# 6. Define unused keyword arguments
|
||||
unused_kwargs = {**config_dict, **kwargs}
|
||||
|
||||
return init_dict, unused_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -13,10 +13,11 @@ deps = {
|
||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
||||
"modelcards": "modelcards>=0.1.4",
|
||||
"numpy": "numpy",
|
||||
"onnxruntime": "onnxruntime",
|
||||
"parameterized": "parameterized",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
|
||||
@@ -16,25 +16,13 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
from . import __version__
|
||||
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
|
||||
from .utils.import_utils import (
|
||||
_flax_version,
|
||||
_jax_version,
|
||||
_onnxruntime_version,
|
||||
_torch_version,
|
||||
is_flax_available,
|
||||
is_modelcards_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .utils import deprecate, is_modelcards_available, logging
|
||||
|
||||
|
||||
if is_modelcards_available():
|
||||
@@ -45,32 +33,6 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
|
||||
SESSION_ID = uuid4().hex
|
||||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
"""
|
||||
Formats a user-agent string with basic info about a request.
|
||||
"""
|
||||
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
|
||||
if DISABLE_TELEMETRY:
|
||||
return ua + "; telemetry/off"
|
||||
if is_torch_available():
|
||||
ua += f"; torch/{_torch_version}"
|
||||
if is_flax_available():
|
||||
ua += f"; jax/{_jax_version}"
|
||||
ua += f"; flax/{_flax_version}"
|
||||
if is_onnx_available():
|
||||
ua += f"; onnxruntime/{_onnxruntime_version}"
|
||||
# CI will set this value to True
|
||||
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
||||
ua += "; is_ci/true"
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
return ua
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
@@ -139,7 +101,7 @@ def init_git_repo(args, at_init: bool = False):
|
||||
|
||||
def push_to_hub(
|
||||
args,
|
||||
pipeline,
|
||||
pipeline: DiffusionPipeline,
|
||||
repo: Repository,
|
||||
commit_message: Optional[str] = "End of training",
|
||||
blocking: bool = True,
|
||||
|
||||
@@ -21,37 +21,18 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_torch_version(">=", "1.9.0"):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
||||
else:
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).device
|
||||
@@ -287,19 +268,6 @@ class ModelMixin(torch.nn.Module):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -328,41 +296,6 @@ class ModelMixin(torch.nn.Module):
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||
)
|
||||
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
@@ -445,8 +378,12 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
# restore default dtype
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
if device_map == "auto":
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config_path,
|
||||
@@ -463,17 +400,7 @@ class ModelMixin(torch.nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if device_map is None, load the state dict and move the params from the meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file)
|
||||
# move the params from the meta device to the cpu
|
||||
for param_name, param in state_dict.items():
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by default the device_map is None and the weights are loaded on the CPU
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
|
||||
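# Usage sketch for the `device_map` / `low_cpu_mem_usage` options documented above
# (illustrative; "some/repo" is a placeholder model id, not a real checkpoint):
#
#     unet = UNet2DConditionModel.from_pretrained("some/repo", subfolder="unet", device_map="auto")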
@@ -16,7 +16,6 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .attention import Transformer2DModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
@@ -12,218 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
||||
for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
else:
|
||||
xformers = None
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
||||
embeddings) inputs.
|
||||
|
||||
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
||||
transformer action. Finally, reshape to image.
|
||||
|
||||
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
||||
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
||||
classes of unnoised image.
|
||||
|
||||
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
||||
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to but not more than steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, height, width)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = in_channels is not None
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
if self.is_input_continuous:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
||||
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
||||
tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
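# Illustrative sketch (not part of this diff; the numbers are hypothetical): a
# discrete, vectorized-input Transformer2DModel is configured via `num_vector_embeds`
# and `sample_size` instead of `in_channels`.
#
#     model = Transformer2DModel(
#         num_attention_heads=16,
#         attention_head_dim=64,
#         num_vector_embeds=4097,      # codebook size + 1 masked class
#         sample_size=32,
#         num_embeds_ada_norm=100,
#     )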
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
@@ -233,19 +27,19 @@ class AttentionBlock(nn.Module):
|
||||
Uses three q, k, v linear layers to compute attention.
|
||||
|
||||
Parameters:
|
||||
channels (`int`): The number of channels in the input and output.
|
||||
num_head_channels (`int`, *optional*):
|
||||
channels (:obj:`int`): The number of channels in the input and output.
|
||||
num_head_channels (:obj:`int`, *optional*):
|
||||
The number of channels in each head. If None, then `num_heads` = 1.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_head_channels: Optional[int] = None,
|
||||
norm_num_groups: int = 32,
|
||||
num_groups: int = 32,
|
||||
rescale_output_factor: float = 1.0,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
@@ -254,7 +48,7 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
||||
self.num_head_size = num_head_channels
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
|
||||
# define q,k,v as linear layers
|
||||
self.query = nn.Linear(channels, channels)
|
||||
@@ -310,108 +104,112 @@ class AttentionBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.

Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""

def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)

self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual

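For orientation, a minimal usage sketch of the SpatialTransformer block shown above; the import path, shapes, and argument values are illustrative assumptions rather than part of the diff:

import torch
from diffusers.models.attention import SpatialTransformer  # assumed location on this branch

block = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, depth=1, context_dim=768)
feature_map = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
context = torch.randn(2, 77, 768)          # e.g. text-encoder hidden states
out = block(feature_map, context=context)  # residual output, same shape as feature_map
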
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.

Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
"""

def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is self-attn if context is none

# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint

def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

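The try/except above is a capability probe: one tiny memory_efficient_attention call is run so that a broken xformers install fails at configuration time instead of mid-denoising. A standalone sketch of the same probe, assuming xformers and a CUDA device are available:

import torch
import xformers.ops

def xformers_attention_works() -> bool:
    # Mirrors the probe in the block above: run attention on a tiny (batch, seq, dim) tensor.
    try:
        q = k = v = torch.randn((1, 2, 40), device="cuda")
        xformers.ops.memory_efficient_attention(q, k, v)
        return True
    except Exception:
        return False
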
def forward(self, hidden_states, context=None, timestep=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states

# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states

# 3. Feed-forward
def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

return hidden_states

@@ -420,28 +218,20 @@ class CrossAttention(nn.Module):
A cross attention layer.

Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
context_dim = context_dim if context_dim is not None else query_dim

self.scale = dim_head**-0.5
self.heads = heads
@@ -449,15 +239,12 @@ class CrossAttention(nn.Module):
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -490,19 +277,13 @@ class CrossAttention(nn.Module):
# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

return self.to_out(hidden_states)

def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
@@ -554,53 +335,31 @@ class CrossAttention(nn.Module):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

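The `_slice_size` bookkeeping above trades speed for memory: queries are reshaped to (batch * heads, seq, dim), and slicing that leading dimension bounds how much of the attention score matrix exists at once. A rough, illustrative estimate (the sizes are assumptions, not from the diff):

# Peak number of attention-score floats held at once, with and without slicing.
batch, heads, seq_len = 2, 8, 64 * 64            # a 64x64 latent flattened to 4096 tokens
full_rows = batch * heads                        # 16 head-batches processed together
full_scores = full_rows * seq_len * seq_len      # ~268M floats for the full score matrix
slice_size = 4                                   # process 4 head-batches per step instead
sliced_scores = slice_size * seq_len * seq_len   # ~67M floats, a 4x smaller peak
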
class FeedForward(nn.Module):
r"""
A feed-forward layer.

Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
project_in = GEGLU(dim, inner_dim)

if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)

self.net = nn.ModuleList([])
# project in
self.net.append(geglu)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
return self.net(hidden_states)

# feedforward
@@ -609,8 +368,8 @@ class GEGLU(nn.Module):
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
@@ -626,38 +385,3 @@ class GEGLU(nn.Module):
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)

class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)

For more details, see section 2: https://arxiv.org/abs/1606.08415
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)

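To make the two activations above concrete, a small illustrative comparison on the same input (shapes are assumptions):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8)                  # (batch, tokens, dim_in)
proj = torch.nn.Linear(8, 2 * 32)          # GEGLU projects to 2 * dim_out, then splits
hidden, gate = proj(x).chunk(2, dim=-1)    # two (2, 16, 32) halves
geglu_out = hidden * F.gelu(gate)          # one half gates the other

approx_proj = torch.nn.Linear(8, 32)       # ApproximateGELU: a single projection...
y = approx_proj(x)
approx_out = y * torch.sigmoid(1.702 * y)  # ...then the sigmoid-based GELU approximation
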
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""

def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x

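AdaLayerNorm predicts a per-timestep scale and shift and applies them on top of a parameter-free LayerNorm. A minimal sketch of the same conditioning pattern, with assumed dimensions:

import torch
import torch.nn as nn

embedding_dim, num_embeddings = 64, 1000
emb = nn.Embedding(num_embeddings, embedding_dim)
to_scale_shift = nn.Linear(embedding_dim, embedding_dim * 2)
norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

x = torch.randn(2, 10, embedding_dim)                # (batch, tokens, dim)
t = torch.tensor(37)                                 # a single diffusion-step index
scale, shift = to_scale_shift(nn.SiLU()(emb(t))).chunk(2, dim=-1)
out = norm(x) * (1 + scale) + shift                  # timestep-conditioned normalization
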
@@ -142,7 +142,7 @@ class FlaxBasicTransformerBlock(nn.Module):
return hidden_states

class FlaxTransformer2DModel(nn.Module):
class FlaxSpatialTransformer(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf

@@ -126,68 +126,3 @@ class GaussianFourierProjection(nn.Module):
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out

class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.

For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

For VQ-diffusion:

Output vector embeddings are used as input for the transformer.

Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""

def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()

self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim

self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)

def forward(self, index):
emb = self.emb(index)

height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)

width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)

pos_emb = height_emb + width_emb

# 1 x H x W x D -> 1 x L x D
pos_emb = pos_emb.view(1, self.height * self.width, -1)

emb = emb + pos_emb[:, : emb.shape[1], :]

return emb

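For readers following the broadcasting in ImagePositionalEmbeddings.forward, a short shape walk-through with illustrative sizes:

import torch

height, width, embed_dim = 4, 4, 8
height_emb = torch.randn(1, height, embed_dim).unsqueeze(2)  # 1 x H x 1 x D
width_emb = torch.randn(1, width, embed_dim).unsqueeze(1)    # 1 x 1 x W x D
pos_emb = height_emb + width_emb                             # broadcasts to 1 x H x W x D
pos_emb = pos_emb.view(1, height * width, -1)                # 1 x (H*W) x D, one slot per latent pixel
print(pos_emb.shape)                                         # torch.Size([1, 16, 8])
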
@@ -17,41 +17,23 @@ import flax.linen as nn
import jax.numpy as jnp

def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
embedding_dim: int,
freq_shift: float = 1,
min_timescale: float = 1,
max_timescale: float = 1.0e4,
flip_sin_to_cos: bool = False,
scale: float = 1.0,
) -> jnp.ndarray:
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
"""
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
num_timescales = float(embedding_dim // 2)
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

# scale embeddings
scaled_time = scale * emb

if flip_sin_to_cos:
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
else:
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
return signal
:param timesteps: a 1-D tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - freq_shift)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb

class FlaxTimestepEmbedding(nn.Module):
@@ -88,6 +70,4 @@ class FlaxTimesteps(nn.Module):

@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
)
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)

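As a sanity check on the simpler freq_shift variant restored above, a two-timestep example with illustrative numbers:

import math
import jax.numpy as jnp

timesteps, embedding_dim, freq_shift = jnp.array([0.0, 10.0]), 4, 1
half_dim = embedding_dim // 2                                  # 2 frequencies
decay = math.log(10000) / (half_dim - freq_shift)              # log-spaced timescales
freqs = jnp.exp(jnp.arange(half_dim) * -decay)                 # [1.0, 1e-4]
angles = timesteps[:, None] * freqs[None, :]                   # (2, 2) phase matrix
emb = jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], -1)  # (2, 4): cos half, then sin half
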
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, Transformer2DModel
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -109,19 +109,6 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "AttnDownEncoderBlock2D":
|
||||
return AttnDownEncoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
@@ -213,17 +200,6 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
)
|
||||
elif up_block_type == "AttnUpDecoderBlock2D":
|
||||
return AttnUpDecoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
@@ -273,7 +249,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -349,13 +325,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
SpatialTransformer(
|
||||
in_channels,
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -391,14 +367,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -451,7 +423,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -462,7 +434,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -530,13 +502,13 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -546,7 +518,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -570,32 +542,25 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -651,7 +616,7 @@ class DownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -729,7 +694,7 @@ class DownEncoderBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -790,7 +755,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -801,7 +766,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -886,7 +851,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
down=True,
|
||||
kernel="fir",
|
||||
)
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
||||
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
||||
else:
|
||||
self.resnet_down = None
|
||||
@@ -966,7 +931,7 @@ class SkipDownBlock2D(nn.Module):
|
||||
down=True,
|
||||
kernel="fir",
|
||||
)
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
||||
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
||||
else:
|
||||
self.resnet_down = None
|
||||
@@ -1041,7 +1006,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1088,6 +1053,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1116,13 +1082,13 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -1152,10 +1118,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -1172,22 +1134,19 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1367,7 +1326,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .attention_flax import FlaxSpatialTransformer
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -196,7 +196,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -326,7 +326,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
in_channels=self.in_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.in_channels // self.attn_num_head_channels,
|
||||
|
||||
@@ -225,17 +225,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.down_blocks:
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
for block in self.up_blocks:
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -233,16 +233,14 @@ class VectorQuantizer(nn.Module):
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(
|
||||
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
|
||||
):
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.vq_embed_dim = vq_embed_dim
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
@@ -289,7 +287,7 @@ class VectorQuantizer(nn.Module):
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
@@ -411,7 +409,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): TODO
|
||||
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
||||
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -428,7 +425,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
sample_size: int = 32,
|
||||
num_vq_embeddings: int = 256,
|
||||
norm_num_groups: int = 32,
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -444,11 +440,11 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
double_z=False,
|
||||
)
|
||||
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
||||
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
||||
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.quantize = VectorQuantizer(
|
||||
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
|
||||
@@ -29,7 +29,6 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .hub_utils import http_user_agent
|
||||
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
||||
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
||||
@@ -161,10 +160,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -276,7 +271,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
>>> # Download pipeline, but overwrite scheduler
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
|
||||
```
|
||||
"""
|
||||
@@ -306,22 +301,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download PyTorch weights
|
||||
ignore_patterns = "*.bin"
|
||||
|
||||
if cls != FlaxDiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
requested_pipeline_class = (
|
||||
requested_pipeline_class
|
||||
if requested_pipeline_class.startswith("Flax")
|
||||
else "Flax" + requested_pipeline_class
|
||||
)
|
||||
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -332,8 +311,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
@@ -351,7 +328,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
if config_dict["_class_name"].startswith("Flax")
|
||||
else "Flax" + config_dict["_class_name"]
|
||||
)
|
||||
pipeline_class = getattr(diffusers_module, class_name)
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
|
||||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
@@ -371,11 +348,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
sub_model_should_be_defined = True
|
||||
|
||||
@@ -30,10 +30,9 @@ from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import __version__
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .hub_utils import http_user_agent
|
||||
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -42,8 +41,6 @@ from .utils import (
|
||||
WEIGHTS_NAME,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
@@ -179,10 +176,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -209,13 +202,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
|
||||
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
|
||||
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
|
||||
" `float16` operations on those devices in PyTorch. Please remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
|
||||
)
|
||||
module.to(torch_device)
|
||||
return self
|
||||
@@ -331,19 +324,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information. specify the folder name here.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
@@ -380,7 +360,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
>>> # Download pipeline, but overwrite scheduler
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
|
||||
```
|
||||
"""
|
||||
@@ -396,34 +376,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
provider = kwargs.pop("provider", None)
|
||||
sess_options = kwargs.pop("sess_options", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
@@ -443,20 +395,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = "*.msgpack"
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
|
||||
if cls != DiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
|
||||
if custom_pipeline is not None:
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
@@ -468,7 +413,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
@@ -525,11 +469,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
@@ -616,12 +555,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
@@ -663,7 +598,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
... StableDiffusionInpaintPipeline,
|
||||
... )
|
||||
|
||||
>>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
|
||||
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
|
||||
```
|
||||
|
||||
@@ -16,7 +16,7 @@ or created independently from each other.
|
||||
|
||||
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
|
||||
More specifically, we strive to provide pipelines that
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
|
||||
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
|
||||
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
|
||||
|
||||
@@ -7,7 +7,6 @@ if is_torch_available():
|
||||
from .ddpm import DDPMPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
else:
|
||||
@@ -16,13 +15,11 @@ else:
|
||||
if is_torch_available() and is_transformers_available():
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .stable_diffusion import (
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -44,7 +44,6 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
use_clipped_model_output: Optional[bool] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -61,9 +60,6 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
use_clipped_model_output (`bool`, *optional*, defaults to `None`):
|
||||
if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
|
||||
downstream to the scheduler. So use `None` for schedulers which don't support this argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -86,14 +82,6 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)

# Ignore use_clipped_model_output if the scheduler doesn't accept this argument
accepts_use_clipped_model_output = "use_clipped_model_output" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_kwargs = {}
if accepts_use_clipped_model_output:
extra_kwargs["use_clipped_model_output"] = use_clipped_model_output

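The signature inspection above is a general pattern for forwarding an optional argument only to schedulers whose step() accepts it. A standalone illustration with hypothetical step functions:

import inspect

def step_without_flag(sample, t):
    return sample

def step_with_flag(sample, t, use_clipped_model_output=None):
    return sample

for step_fn in (step_without_flag, step_with_flag):
    extra = {}
    if "use_clipped_model_output" in inspect.signature(step_fn).parameters:
        extra["use_clipped_model_output"] = True
    step_fn(0, 0, **extra)  # the flag is passed only where it is accepted
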
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t).sample
|
||||
@@ -101,7 +89,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
|
||||
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
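For reference, a hedged sketch of calling `DDIMPipeline` with the arguments documented in the hunk above (`generator`, `eta`, `num_inference_steps`, `output_type`); the checkpoint id is an assumption.

```python
# Hedged sketch, not part of the diff: DDIMPipeline inference.
import torch
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")  # assumed checkpoint id
generator = torch.Generator().manual_seed(0)
image = pipe(generator=generator, eta=0.0, num_inference_steps=50, output_type="pil").images[0]
image.save("ddim_sample.png")
```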
@@ -45,7 +45,6 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
predict_epsilon: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
@@ -85,9 +84,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(
|
||||
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
|
||||
).prev_sample
|
||||
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .pipeline_repaint import RePaintPipeline
|
||||
@@ -1,140 +0,0 @@
|
||||
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import RePaintScheduler
|
||||
|
||||
|
||||
def _preprocess_image(image: PIL.Image.Image):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
return image
|
||||
|
||||
|
||||
def _preprocess_mask(mask: PIL.Image.Image):
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask
|
||||
|
||||
|
||||
class RePaintPipeline(DiffusionPipeline):
|
||||
unet: UNet2DModel
|
||||
scheduler: RePaintScheduler
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
original_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
num_inference_steps: int = 250,
|
||||
eta: float = 0.0,
|
||||
jump_length: int = 10,
|
||||
jump_n_sample: int = 10,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The original image to inpaint on.
|
||||
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The mask_image, where pixel values of 0.0 define which parts of the original image to inpaint (i.e. change).
|
||||
num_inference_steps (`int`, *optional*, defaults to 250):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
eta (`float`):
The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds
to the DDIM scheduler and 1.0 to the DDPM scheduler, respectively.
|
||||
jump_length (`int`, *optional*, defaults to 10):
|
||||
The number of steps taken forward in time before going backward in time for a single jump ("j" in
the RePaint paper). Take a look at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
jump_n_sample (`int`, *optional*, defaults to 10):
The number of times the forward time jump is repeated for a given chosen time sample. Take a look
at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if not isinstance(original_image, torch.FloatTensor):
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(self.device)
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(self.device)
|
||||
|
||||
# sample gaussian noise to begin the loop
|
||||
image = torch.randn(
|
||||
original_image.shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
||||
self.scheduler.eta = eta
|
||||
|
||||
t_last = self.scheduler.timesteps[0] + 1
|
||||
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
|
||||
if t < t_last:
|
||||
# predict the noise residual
|
||||
model_output = self.unet(image, t).sample
|
||||
# compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
|
||||
|
||||
else:
|
||||
# compute the reverse: x_t-1 -> x_t
|
||||
image = self.scheduler.undo_step(image, t_last, generator)
|
||||
t_last = t
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
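A hedged usage sketch for the `RePaintPipeline` defined above. The checkpoint id `google/ddpm-ema-celebahq-256` and the two local image files are assumptions; regions where the mask is 0.0 are the ones that get regenerated, as documented in the docstring.

```python
# Hedged sketch, not part of the diff: inpainting with RePaintPipeline.
from PIL import Image
from diffusers import RePaintPipeline, RePaintScheduler

original_image = Image.open("celeba_hq_256.png").convert("RGB")  # assumed local file
mask_image = Image.open("mask_256.png").convert("L")             # assumed local file

scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256", subfolder="scheduler")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler).to("cuda")

output = pipe(
    original_image=original_image,
    mask_image=mask_image,
    num_inference_steps=250,
    eta=0.0,
    jump_length=10,
    jump_n_sample=10,
)
output.images[0].save("inpainted.png")
```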
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
|
||||
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -91,7 +91,11 @@ image.save("astronaut_rides_horse.png")
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -103,74 +107,3 @@ image = pipe(prompt).sample[0]
|
||||
|
||||
image.save("astronaut_rides_horse.png")
|
||||
```
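The hunks above show the two ways of obtaining a scheduler: loading its config from the Hub versus constructing it with explicit arguments. A hedged sketch of attaching either one to an already-loaded pipeline follows; the direct attribute assignment is an assumption about the pipeline API.

```python
# Hedged sketch: swapping the scheduler on an existing pipeline
# (constructor kwargs mirror the README hunks above).
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
image = pipe("an astronaut riding a horse").images[0]  # `.sample[0]` in the older API shown above
image.save("astronaut_rides_horse.png")
```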
### CycleDiffusion using Stable Diffusion and DDIM scheduler

```python
import requests
import torch
from PIL import Image
from io import BytesIO

from diffusers import CycleDiffusionPipeline, DDIMScheduler


# load the scheduler. CycleDiffusion only supports stochastic schedulers.

# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")

# let's download an initial image
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("horse.png")

# let's specify a prompt
source_prompt = "An astronaut riding a horse"
prompt = "An astronaut riding an elephant"

# call the pipeline
image = pipe(
    prompt=prompt,
    source_prompt=source_prompt,
    init_image=init_image,
    num_inference_steps=100,
    eta=0.1,
    strength=0.8,
    guidance_scale=2,
    source_guidance_scale=1,
).images[0]

image.save("horse_to_elephant.png")

# let's try another example
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image.save("black.png")

source_prompt = "A black colored car"
prompt = "A blue colored car"

# call the pipeline
torch.manual_seed(0)
image = pipe(
    prompt=prompt,
    source_prompt=source_prompt,
    init_image=init_image,
    num_inference_steps=100,
    eta=0.1,
    strength=0.85,
    guidance_scale=3,
    source_guidance_scale=1,
).images[0]

image.save("black_to_blue.png")
```
@@ -28,7 +28,6 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
|
||||
from .pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
@@ -1,527 +0,0 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
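# at this point `image` is a float32 NCHW tensor in [0, 1]; the return below maps it to [-1, 1]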
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def posterior_sample(scheduler, latents, timestep, clean_latents, eta):
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
|
||||
|
||||
if prev_timestep <= 0:
|
||||
return clean_latents
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = (
|
||||
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
|
||||
)
|
||||
|
||||
variance = scheduler._get_variance(timestep, prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
# direction pointing to x_t
|
||||
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
|
||||
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
|
||||
noise = std_dev_t * torch.randn(clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device)
|
||||
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
|
||||
|
||||
return prev_latents
|
||||
|
||||
|
||||
def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = (
|
||||
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
|
||||
)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if scheduler.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = scheduler._get_variance(timestep, prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
|
||||
|
||||
noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
|
||||
variance ** (0.5) * eta
|
||||
)
|
||||
return noise
|
||||
|
||||
|
||||
class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-guided image to image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `set_attention_slice`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
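A hedged usage sketch for the attention-slicing switch documented above; the checkpoint id and scheduler setup are assumptions that mirror the README example earlier in this diff.

```python
# Hedged sketch: trading a small slowdown for lower peak memory.
from diffusers import CycleDiffusionPipeline, DDIMScheduler

scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda")

pipe.enable_attention_slicing()      # "auto": slice_size = attention_head_dim // 2
# pipe.enable_attention_slicing(1)   # integer: attention_head_dim // slice_size slices
pipe.disable_attention_slicing()     # back to a single full-attention pass
```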
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
source_prompt: Union[str, List[str]],
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
source_guidance_scale: Optional[float] = 1,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
|
||||
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
||||
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
||||
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. This parameter will be modulated by `strength`.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
source_guidance_scale (`float`, *optional*, defaults to 1):
|
||||
Guidance scale for the source prompt. This is useful to control the amount of influence the source
prompt has on the encoding.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.1):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(
|
||||
"At the moment only `batch_size=1` is supported for prompts, but you seem to have passed multiple"
|
||||
f" prompts: {prompt}. Please make sure to pass only a single prompt."
|
||||
)
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
source_text_inputs = self.tokenizer(
|
||||
source_prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
source_text_input_ids = source_text_inputs.input_ids
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
uncond_tokens = [""]
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
source_uncond_tokens = [""]
|
||||
|
||||
max_length = source_text_input_ids.shape[-1]
|
||||
source_uncond_input = self.tokenizer(
|
||||
source_uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
|
||||
batch_size * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many init images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
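# e.g. with num_inference_steps=100, strength=0.8 and offset=1: init_timestep = 81, so noise is added at the 81st timestep from the end and denoising covers the remaining schedule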
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
clean_latents = init_latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
|
||||
if not (accepts_eta and (0 < eta <= 1)):
|
||||
raise ValueError(
|
||||
"Currently, only the DDIM scheduler is supported. Please make sure that `pipeline.scheduler` is of"
|
||||
f" type {DDIMScheduler.__class__} and not {self.scheduler.__class__}."
|
||||
)
|
||||
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
latents = init_latents
|
||||
source_latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
source_latent_model_input = torch.cat([source_latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
concat_latent_model_input = torch.stack(
|
||||
[
|
||||
source_latent_model_input[0],
|
||||
latent_model_input[0],
|
||||
source_latent_model_input[1],
|
||||
latent_model_input[1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
concat_text_embeddings = torch.stack(
|
||||
[
|
||||
source_text_embeddings[0],
|
||||
text_embeddings[0],
|
||||
source_text_embeddings[1],
|
||||
text_embeddings[1],
|
||||
],
|
||||
dim=0,
|
||||
)
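# batch order is [source-uncond, uncond, source-cond, cond]; concat_noise_pred below is chunked in the same order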
|
||||
concat_noise_pred = self.unet(
|
||||
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
(
|
||||
source_noise_pred_uncond,
|
||||
noise_pred_uncond,
|
||||
source_noise_pred_text,
|
||||
noise_pred_text,
|
||||
) = concat_noise_pred.chunk(4, dim=0)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
|
||||
source_noise_pred_text - source_noise_pred_uncond
|
||||
)
|
||||
|
||||
# Sample source_latents from the posterior distribution.
|
||||
prev_source_latents = posterior_sample(
|
||||
self.scheduler, source_latents, t, clean_latents, **extra_step_kwargs
|
||||
)
|
||||
# Compute noise.
|
||||
noise = compute_noise(
|
||||
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
|
||||
)
|
||||
source_latents = prev_source_latents
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -14,12 +14,7 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
||||
from ...pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
from ...schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
)
|
||||
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
|
||||
from ...utils import logging
|
||||
from . import FlaxStableDiffusionPipelineOutput
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
@@ -48,8 +43,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
|
||||
[`FlaxDPMSolverMultistepScheduler`].
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
@@ -63,9 +57,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
text_encoder: FlaxCLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: FlaxUNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
|
||||
@@ -5,7 +5,6 @@ import numpy as np
|
||||
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...onnx_utils import OnnxRuntimeModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
@@ -37,34 +36,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
||||
@@ -90,19 +90,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
|
||||
@@ -104,19 +104,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
|
||||
@@ -9,14 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -59,14 +52,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
@@ -86,19 +72,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -119,24 +92,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
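A hedged sketch of toggling the xformers switch documented above; it requires the optional `xformers` package and a CUDA device, and the checkpoint id is an assumption.

```python
# Hedged sketch: memory-efficient attention via xformers.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.enable_xformers_memory_efficient_attention()
image = pipe("a photo of an astronaut riding a horse").images[0]
pipe.disable_xformers_memory_efficient_attention()
```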
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -178,8 +133,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
@@ -303,7 +257,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -332,7 +286,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
@@ -379,11 +333,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
@@ -5,19 +5,12 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -70,9 +63,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
|
||||
],
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
@@ -92,19 +83,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -152,41 +130,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `set_attention_slice`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta')` and loaded to the GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
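A hedged sketch of the sequential CPU offload described above; the prompt and `init_image` are placeholders, and `accelerate` must be installed.

```python
# Hedged sketch: submodules are moved to the GPU only while their forward
# pass runs, at the cost of slower inference but a much smaller memory footprint.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_sequential_cpu_offload()
image = pipe(prompt="a fantasy landscape", init_image=init_image, strength=0.75).images[0]
```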
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -313,7 +256,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -337,9 +280,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -394,11 +335,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
@@ -5,7 +5,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -91,20 +90,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
|
||||
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
|
||||
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
|
||||
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
|
||||
" Hub, it would be very nice if you could open a Pull request for the"
|
||||
" `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["skip_prk_steps"] = True
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -152,41 +137,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
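A minimal sketch of how this offload hook is typically used (assuming accelerate is installed and an illustrative checkpoint id):

# Hedged sketch, not part of the diff: sequential CPU offload keeps weights off the GPU between forward calls.
import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # no explicit pipe.to("cuda"); submodules are moved per forward() call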
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -320,7 +270,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -349,7 +299,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
@@ -379,14 +329,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
# prepare mask and masked_image
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
||||
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
@@ -401,9 +348,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
|
||||
@@ -435,11 +379,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
@@ -96,19 +96,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -284,7 +271,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -312,9 +299,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@@ -367,11 +352,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
latents = init_latents
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .pipeline_vq_diffusion import VQDiffusionPipeline
|
||||
@@ -1,253 +0,0 @@
|
||||
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Transformer2DModel, VQModel
|
||||
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class VQDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using VQ Diffusion
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vqvae ([`VQModel`]):
|
||||
Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
|
||||
representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. VQ Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
Conditional transformer to denoise the encoded image latents.
|
||||
scheduler ([`VQDiffusionScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
vqvae: VQModel
|
||||
text_encoder: CLIPTextModel
|
||||
tokenizer: CLIPTokenizer
|
||||
transformer: Transformer2DModel
|
||||
scheduler: VQDiffusionScheduler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
transformer: Transformer2DModel,
|
||||
scheduler: VQDiffusionScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vqvae=vqvae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_inference_steps: int = 100,
|
||||
truncation_rate: float = 1.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
|
||||
Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
|
||||
most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
|
||||
`truncation_rate` are set to zero.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor` of shape (batch), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
|
||||
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
|
||||
be generated of completely masked latent pixels.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
|
||||
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
|
||||
# While CLIP does normalize the pooled output of the text transformer when combining
|
||||
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
|
||||
#
|
||||
# CLIP normalizing the pooled output.
|
||||
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
|
||||
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# get the initial completely masked latents unless the user supplied it
|
||||
|
||||
latents_shape = (batch_size, self.transformer.num_latent_pixels)
|
||||
if latents is None:
|
||||
mask_class = self.transformer.num_vector_embeds - 1
|
||||
latents = torch.full(latents_shape, mask_class).to(self.device)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
|
||||
raise ValueError(
|
||||
"Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
|
||||
f" {self.transformer.num_vector_embeds - 1} (inclusive)."
|
||||
)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
sample = latents
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# predict the un-noised image
|
||||
# model_output == `log_p_x_0`
|
||||
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
|
||||
|
||||
model_output = self.truncate(model_output, truncation_rate)
|
||||
|
||||
# remove `log(0)`'s (`-inf`s)
|
||||
model_output = model_output.clamp(-70)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, sample)
|
||||
|
||||
embedding_channels = self.vqvae.config.vq_embed_dim
|
||||
embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
|
||||
embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
|
||||
image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
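A minimal end-to-end sketch of calling this pipeline (the hub id and prompt are illustrative assumptions):

# Hedged sketch, not part of the diff: text-to-image with VQ Diffusion.
import torch
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq").to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe("teddy bear playing in the pool", truncation_rate=0.86, generator=generator).images[0]
image.save("teddy.png")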
|
||||
|
||||
def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
|
||||
"""
|
||||
Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate`.
The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero.
|
||||
"""
|
||||
sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
|
||||
sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
|
||||
keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
|
||||
|
||||
# Ensure that at least the largest probability is not zeroed out
|
||||
all_true = torch.full_like(keep_mask[:, 0:1, :], True)
|
||||
keep_mask = torch.cat((all_true, keep_mask), dim=1)
|
||||
keep_mask = keep_mask[:, :-1, :]
|
||||
|
||||
keep_mask = keep_mask.gather(1, indices.argsort(1))
|
||||
|
||||
rv = log_p_x_0.clone()
|
||||
|
||||
rv[~keep_mask] = -torch.inf # -inf = log(0)
|
||||
|
||||
return rv
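To make the masking rule concrete, a standalone toy illustration (the numbers and `truncation_rate=0.86` are arbitrary assumptions, not taken from the diff):

# Hedged sketch: which classes survive truncation for a single latent pixel.
import torch

p = torch.tensor([0.05, 0.50, 0.30, 0.15])                 # toy class probabilities
sorted_p, _ = torch.sort(p, descending=True)               # 0.50, 0.30, 0.15, 0.05
keep = sorted_p.cumsum(0) < 0.86                           # True, True, False, False
keep = torch.cat((torch.tensor([True]), keep))[:-1]        # shift so the top class is always kept
print(keep)                                                # True, True, True, False -> only p=0.05 is zeroed out (log(0))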
|
||||
@@ -19,24 +19,18 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
|
||||
if is_torch_available():
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
from .scheduling_ddpm import DDPMScheduler
|
||||
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
||||
from .scheduling_euler_discrete import EulerDiscreteScheduler
|
||||
from .scheduling_ipndm import IPNDMScheduler
|
||||
from .scheduling_karras_ve import KarrasVeScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
from .scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
else:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
if is_flax_available():
|
||||
from .scheduling_ddim_flax import FlaxDDIMScheduler
|
||||
from .scheduling_ddpm_flax import FlaxDDPMScheduler
|
||||
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
|
||||
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
|
||||
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
|
||||
from .scheduling_pndm_flax import FlaxPNDMScheduler
|
||||
|
||||
@@ -109,15 +109,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"PNDMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -209,7 +200,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -222,14 +212,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
|
||||
use_clipped_model_output (`bool`): TODO
|
||||
generator: random number generator.
|
||||
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
|
||||
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
@@ -289,17 +273,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if eta > 0:
|
||||
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
if variance_noise is not None and generator is not None:
|
||||
raise ValueError(
|
||||
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
||||
" `variance_noise` stays `None`."
|
||||
)
|
||||
|
||||
if variance_noise is None:
|
||||
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
|
||||
device
|
||||
)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
|
||||
|
||||
prev_sample = prev_sample + variance
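A minimal sketch of supplying `variance_noise` directly instead of a `generator` (shapes and values are illustrative):

# Hedged sketch, not part of the diff: precomputed variance noise, as used e.g. by CycleDiffusion.
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)
sample = torch.randn(1, 4, 64, 64)                 # stand-in for the current latents
model_output = torch.randn(1, 4, 64, 64)           # stand-in for a UNet prediction
noise = torch.randn_like(sample)                   # externally computed variance noise
out = scheduler.step(model_output, scheduler.timesteps[0], sample, eta=1.0, variance_noise=noise)
prev_sample = out.prev_sample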
|
||||
|
||||
|
||||
@@ -102,15 +102,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,506 +0,0 @@
|
||||
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
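As a quick sanity check of the cosine ("squaredcos_cap_v2") schedule produced by this helper (import path assumed from the file location; printed values are approximate):

# Hedged sketch, not part of the diff.
from diffusers.schedulers.scheduling_dpmsolver_multistep import betas_for_alpha_bar

betas = betas_for_alpha_bar(1000)
print(betas.shape)       # torch.Size([1000])
print(float(betas[0]))   # very small first beta, on the order of 4e-5
print(float(betas[-1]))  # capped at max_beta = 0.999 as alpha_bar goes to zero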
|
||||
|
||||
|
||||
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
|
||||
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
|
||||
samples, and it can generate quite good samples even in only 10 steps.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
|
||||
|
||||
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
|
||||
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
|
||||
|
||||
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
|
||||
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
|
||||
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
|
||||
models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487).
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, default `dpmsolver++`):
|
||||
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
|
||||
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
|
||||
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
|
||||
sampling (e.g. stable-diffusion).
|
||||
solver_type (`str`, default `midpoint`):
|
||||
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
|
||||
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
|
||||
slightly better, so we recommend to use the `midpoint` type.
|
||||
lower_order_final (`bool`, default `True`):
|
||||
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
|
||||
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
# Currently we only support VP-type noise schedule
|
||||
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
|
||||
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = (
|
||||
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
.round()[::-1][:-1]
|
||||
.copy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [
|
||||
None,
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
|
||||
|
||||
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
|
||||
discretize an integral of the data prediction model. So we need to first convert the model output to the
|
||||
corresponding type to match the algorithm.
|
||||
|
||||
Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
|
||||
DPM-Solver++ for both noise prediction model and data prediction model.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the converted model output.
|
||||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
x0_pred = model_output
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
return model_output
|
||||
else:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
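A tiny numeric check of the identity behind this conversion, x_t = alpha_t * x0 + sigma_t * eps (toy scalars, not real model outputs):

# Hedged sketch, not part of the diff.
import torch

alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)    # satisfies alpha_t**2 + sigma_t**2 == 1
x0, eps = torch.tensor(0.25), torch.tensor(-1.3)
sample = alpha_t * x0 + sigma_t * eps                       # forward noising
x0_pred = (sample - sigma_t * eps) / alpha_t                # "dpmsolver++" branch with predict_epsilon=True
eps_pred = (sample - alpha_t * x0) / sigma_t                # "dpmsolver" branch with predict_epsilon=False
print(torch.allclose(x0_pred, x0), torch.allclose(eps_pred, eps))   # True True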
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the first-order DPM-Solver (equivalent to DDIM).
|
||||
|
||||
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
||||
h = lambda_t - lambda_s
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the second-order multistep DPM-Solver.
|
||||
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
||||
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
||||
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
||||
)
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: List[torch.FloatTensor],
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
One step for the third-order multistep DPM-Solver.
|
||||
|
||||
Args:
|
||||
model_output_list (`List[torch.FloatTensor]`):
|
||||
direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
||||
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
||||
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
||||
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
||||
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the multistep DPM-Solver.
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
|
||||
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero()
|
||||
if len(step_index) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
else:
|
||||
step_index = step_index.item()
|
||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
||||
lower_order_final = (
|
||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
lower_order_second = (
|
||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
)
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.model_outputs[-1] = model_output
|
||||
|
||||
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
||||
prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
|
||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
else:
|
||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||
self.model_outputs, timestep_list, prev_timestep, sample
|
||||
)
|
||||
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
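A minimal sketch of plugging this scheduler into a Stable Diffusion pipeline (the CUDA device and checkpoint id are illustrative assumptions):

# Hedged sketch, not part of the diff: around 20 steps is usually enough with DPM-Solver++.
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]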
|
||||
@@ -1,590 +0,0 @@
|
||||
# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
|
||||
Returns:
|
||||
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return jnp.array(betas, dtype=jnp.float32)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DPMSolverMultistepSchedulerState:
|
||||
# setable values
|
||||
num_inference_steps: Optional[int] = None
|
||||
timesteps: Optional[jnp.ndarray] = None
|
||||
|
||||
# running values
|
||||
model_outputs: Optional[jnp.ndarray] = None
|
||||
lower_order_nums: Optional[int] = None
|
||||
step_index: Optional[int] = None
|
||||
prev_timestep: Optional[int] = None
|
||||
cur_sample: Optional[jnp.ndarray] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, num_train_timesteps: int):
|
||||
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
|
||||
state: DPMSolverMultistepSchedulerState
|
||||
|
||||
|
||||
class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
|
||||
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
|
||||
samples, and it can generate quite good samples even in only 10 steps.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
|
||||
|
||||
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
|
||||
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
|
||||
|
||||
We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
|
||||
diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
|
||||
thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
|
||||
models (such as stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487).
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, default `dpmsolver++`):
|
||||
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
|
||||
algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
|
||||
https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
|
||||
sampling (e.g. stable-diffusion).
|
||||
solver_type (`str`, default `midpoint`):
|
||||
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
|
||||
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
|
||||
slightly better, so we recommend to use the `midpoint` type.
|
||||
lower_order_final (`bool`, default `True`):
|
||||
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
|
||||
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
# Currently we only support VP-type noise schedule
|
||||
self.alpha_t = jnp.sqrt(self.alphas_cumprod)
|
||||
self.sigma_t = jnp.sqrt(1 - self.alphas_cumprod)
|
||||
self.lambda_t = jnp.log(self.alpha_t) - jnp.log(self.sigma_t)
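# lambda_t = log(alpha_t) - log(sigma_t) is one half of the log-SNR; DPM-Solver measures its step
# sizes h as differences of lambda_t between timesteps.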
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
|
||||
raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
def create_state(self):
|
||||
return DPMSolverMultistepSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
def set_timesteps(
|
||||
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> DPMSolverMultistepSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
state (`DPMSolverMultistepSchedulerState`):
|
||||
the `FlaxDPMSolverMultistepScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
shape (`Tuple`):
|
||||
the shape of the samples to be generated.
|
||||
"""
|
||||
timesteps = (
|
||||
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
.round()[::-1][:-1]
|
||||
.astype(jnp.int32)
|
||||
)
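# linspace over [0, num_train_timesteps - 1] with num_inference_steps + 1 points, rounded, reversed and
# with the final boundary dropped, yields num_inference_steps strictly decreasing training timesteps.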
|
||||
|
||||
return state.replace(
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
model_outputs=jnp.zeros((self.config.solver_order,) + shape),
|
||||
lower_order_nums=0,
|
||||
step_index=0,
|
||||
prev_timestep=-1,
|
||||
cur_sample=jnp.zeros(shape),
|
||||
)
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
|
||||
|
||||
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
|
||||
discretize an integral of the data prediction model. So we need to first convert the model output to the
|
||||
corresponding type to match the algorithm.
|
||||
|
||||
Note that the algorithm type and the model type are decoupled. That is to say, we can use either DPM-Solver or
|
||||
DPM-Solver++ for both noise prediction model and data prediction model.
|
||||
|
||||
Args:
|
||||
model_output (`jnp.ndarray`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: the converted model output.
|
||||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
x0_pred = model_output
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = jnp.percentile(
|
||||
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
|
||||
)
|
||||
dynamic_max_val = jnp.maximum(
|
||||
dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
|
||||
)
|
||||
x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
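# Dynamic thresholding: take a per-sample percentile of |x0_pred| (but at least sample_max_value),
# clip x0_pred to that threshold, then divide by it so the result lies in [-1, 1].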
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
return model_output
|
||||
else:
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
One step for the first-order DPM-Solver (equivalent to DDIM).
|
||||
|
||||
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
|
||||
|
||||
Args:
|
||||
model_output (`jnp.ndarray`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0 = prev_timestep, timestep
|
||||
m0 = model_output
|
||||
lambda_t, lambda_s = self.lambda_t[t], self.lambda_t[s0]
|
||||
alpha_t, alpha_s = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s = self.sigma_t[t], self.sigma_t[s0]
|
||||
h = lambda_t - lambda_s
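# Exact first-order (exponential integrator) update: holding the converted model output constant over the
# step and integrating in lambda gives the closed-form expressions below; with epsilon prediction this
# coincides with DDIM.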
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self,
|
||||
model_output_list: jnp.ndarray,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
One step for the second-order multistep DPM-Solver.
|
||||
|
||||
Args:
|
||||
model_output_list (`List[jnp.ndarray]`):
|
||||
direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
||||
r0 = h_0 / h
|
||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
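# D0 is the most recent (converted) model output and D1 a finite-difference estimate of its derivative
# with respect to lambda; the midpoint and heun variants below only weight D1 slightly differently.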
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
|
||||
- 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
|
||||
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
|
||||
)
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
if self.config.solver_type == "midpoint":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
|
||||
- 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
|
||||
)
|
||||
elif self.config.solver_type == "heun":
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
)
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_third_order_update(
|
||||
self,
|
||||
model_output_list: jnp.ndarray,
|
||||
timestep_list: List[int],
|
||||
prev_timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
One step for the third-order multistep DPM-Solver.
|
||||
|
||||
Args:
|
||||
model_output_list (`List[jnp.ndarray]`):
|
||||
direct outputs from learned diffusion model at current and latter timesteps.
|
||||
timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
|
||||
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: the sample tensor at the previous timestep.
|
||||
"""
|
||||
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
||||
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
||||
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
||||
self.lambda_t[t],
|
||||
self.lambda_t[s0],
|
||||
self.lambda_t[s1],
|
||||
self.lambda_t[s2],
|
||||
)
|
||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
||||
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
||||
r0, r1 = h_0 / h, h_1 / h
|
||||
D0 = m0
|
||||
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
||||
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
||||
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
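# D1 and D2 are backward-difference estimates of the first and second derivatives of the model output
# with respect to lambda, built from the three most recent outputs.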
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
x_t = (
|
||||
(sigma_t / sigma_s0) * sample
|
||||
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
|
||||
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
|
||||
- (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||
x_t = (
|
||||
(alpha_t / alpha_s0) * sample
|
||||
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
|
||||
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
|
||||
- (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
||||
)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
state: DPMSolverMultistepSchedulerState,
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
|
||||
from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
state (`DPMSolverMultistepSchedulerState`):
|
||||
the `FlaxDPMSolverMultistepScheduler` state data class instance.
|
||||
model_output (`jnp.ndarray`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
prev_timestep = jax.lax.cond(
|
||||
state.step_index == len(state.timesteps) - 1,
|
||||
lambda _: 0,
|
||||
lambda _: state.timesteps[state.step_index + 1],
|
||||
(),
|
||||
)
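# jax.lax.cond keeps `step` traceable under jit: on the final step the previous timestep is 0, otherwise
# it is the next entry of the precomputed timestep schedule.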
|
||||
|
||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||
|
||||
model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
|
||||
model_outputs_new = model_outputs_new.at[-1].set(model_output)
|
||||
state = state.replace(
|
||||
model_outputs=model_outputs_new,
|
||||
prev_timestep=prev_timestep,
|
||||
cur_sample=sample,
|
||||
)
|
||||
|
||||
def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
|
||||
return self.dpm_solver_first_order_update(
|
||||
state.model_outputs[-1],
|
||||
state.timesteps[state.step_index],
|
||||
state.prev_timestep,
|
||||
state.cur_sample,
|
||||
)
|
||||
|
||||
def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
|
||||
def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
|
||||
timestep_list = jnp.array([state.timesteps[state.step_index - 1], state.timesteps[state.step_index]])
|
||||
return self.multistep_dpm_solver_second_order_update(
|
||||
state.model_outputs,
|
||||
timestep_list,
|
||||
state.prev_timestep,
|
||||
state.cur_sample,
|
||||
)
|
||||
|
||||
def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
|
||||
timestep_list = jnp.array(
|
||||
[
|
||||
state.timesteps[state.step_index - 2],
|
||||
state.timesteps[state.step_index - 1],
|
||||
state.timesteps[state.step_index],
|
||||
]
|
||||
)
|
||||
return self.multistep_dpm_solver_third_order_update(
|
||||
state.model_outputs,
|
||||
timestep_list,
|
||||
state.prev_timestep,
|
||||
state.cur_sample,
|
||||
)
|
||||
|
||||
if self.config.solver_order == 2:
|
||||
return step_2(state)
|
||||
elif self.config.lower_order_final and len(state.timesteps) < 15:
|
||||
return jax.lax.cond(
|
||||
state.lower_order_nums < 2,
|
||||
step_2,
|
||||
lambda state: jax.lax.cond(
|
||||
state.step_index == len(state.timesteps) - 2,
|
||||
step_2,
|
||||
step_3,
|
||||
state,
|
||||
),
|
||||
state,
|
||||
)
|
||||
else:
|
||||
return jax.lax.cond(
|
||||
state.lower_order_nums < 2,
|
||||
step_2,
|
||||
step_3,
|
||||
state,
|
||||
)
|
||||
|
||||
if self.config.solver_order == 1:
|
||||
prev_sample = step_1(state)
|
||||
elif self.config.lower_order_final and len(state.timesteps) < 15:
|
||||
prev_sample = jax.lax.cond(
|
||||
state.lower_order_nums < 1,
|
||||
step_1,
|
||||
lambda state: jax.lax.cond(
|
||||
state.step_index == len(state.timesteps) - 1,
|
||||
step_1,
|
||||
step_23,
|
||||
state,
|
||||
),
|
||||
state,
|
||||
)
|
||||
else:
|
||||
prev_sample = jax.lax.cond(
|
||||
state.lower_order_nums < 1,
|
||||
step_1,
|
||||
step_23,
|
||||
state,
|
||||
)
|
||||
|
||||
state = state.replace(
|
||||
lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
|
||||
step_index=(state.step_index + 1),
|
||||
)
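# lower_order_nums tracks warm-up: a k-th order multistep update needs k previous model outputs, so the
# first step(s) fall back to lower-order updates until enough outputs have been accumulated.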
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
state (`DPMSolverMultistepSchedulerState`):
|
||||
the `FlaxDPMSolverMultistepScheduler` state data class instance.
|
||||
sample (`jnp.ndarray`): input sample
|
||||
timestep (`int`, optional): current timestep
|
||||
|
||||
Returns:
|
||||
`jnp.ndarray`: scaled input sample
|
||||
"""
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: jnp.ndarray,
|
||||
noise: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -1,267 +0,0 @@
|
||||
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
|
||||
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = self.sigmas.max()
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.is_scale_input_called = False
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`float`): current timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
generator (`torch.Generator`, optional): Random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
|
||||
a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
logger.warn(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
sigma_from = self.sigmas[step_index]
|
||||
sigma_to = self.sigmas[step_index + 1]
|
||||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
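# Ancestral step: the move from sigma_from to sigma_to is split into a deterministic part down to
# sigma_down and a stochastic part of size sigma_up, chosen so that sigma_down**2 + sigma_up**2 == sigma_to**2.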
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma
|
||||
|
||||
dt = sigma_down - sigma
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
if str(device) == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
||||
device
|
||||
)
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return EulerAncestralDiscreteSchedulerOutput(
|
||||
prev_sample=prev_sample, pred_original_sample=pred_original_sample
|
||||
)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
schedule_timesteps = self.timesteps
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = self.sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -1,276 +0,0 @@
|
||||
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
|
||||
class EulerDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. Based on the original
|
||||
k-diffusion implementation by Katherine Crowson:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = self.sigmas.max()
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.is_scale_input_called = False
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`float`): current timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
s_churn (`float`): amount of stochastic "churn" to add; it controls the `gamma` factor used below.
s_tmin (`float`): lower bound of the sigma range in which churn is applied.
s_tmax (`float`): upper bound of the sigma range in which churn is applied.
s_noise (`float`): scaling factor for the noise added during churn.
|
||||
generator (`torch.Generator`, optional): Random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
logger.warn(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
device = model_output.device if torch.is_tensor(model_output) else "cpu"
|
||||
if str(device) == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
|
||||
device
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
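# "Churn" (Karras et al., Algorithm 2): when gamma > 0 the sample is perturbed with fresh noise so that
# its effective noise level is raised from sigma to sigma_hat before taking the Euler step.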
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
||||
dt = self.sigmas[step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
schedule_timesteps = self.timesteps
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = self.sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -67,15 +67,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -212,7 +203,22 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
deprecate(
|
||||
"timestep as an index",
|
||||
"0.8.0",
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
standard_warn=False,
|
||||
)
|
||||
step_index = timestep
|
||||
else:
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
@@ -255,7 +261,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
schedule_timesteps = self.timesteps
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
|
||||
deprecate(
|
||||
"timesteps as indices",
|
||||
"0.8.0",
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
|
||||
" pass values from `scheduler.timesteps` as timesteps.",
|
||||
standard_warn=False,
|
||||
)
|
||||
step_indices = timesteps
|
||||
else:
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = self.sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -88,15 +88,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class RePaintSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from
|
||||
the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: torch.FloatTensor
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
|
||||
def alpha_bar(time_step):
|
||||
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RePaint is a scheduler for DDPM-based inpainting inside a given mask.
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
beta_start (`float`): the starting `beta` value of inference.
|
||||
beta_end (`float`): the final `beta` value.
|
||||
beta_schedule (`str`):
|
||||
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
eta (`float`):
|
||||
The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to the
DDIM scheduler and 1.0 to the DDPM scheduler, respectively.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
variance_type (`str`):
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
eta: float = 0.0,
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
clip_sample: bool = True,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
# GeoDiff sigmoid schedule
|
||||
betas = torch.linspace(-6, 6, num_train_timesteps)
|
||||
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.one = torch.tensor(1.0)
|
||||
|
||||
self.final_alpha_cumprod = torch.tensor(1.0)
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||
|
||||
self.eta = eta
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): input sample
|
||||
timestep (`int`, optional): current timestep
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
jump_length: int = 10,
|
||||
jump_n_sample: int = 10,
|
||||
device: Union[str, torch.device] = None,
|
||||
):
|
||||
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = []
|
||||
|
||||
jumps = {}
|
||||
for j in range(0, num_inference_steps - jump_length, jump_length):
|
||||
jumps[j] = jump_n_sample - 1
|
||||
|
||||
t = num_inference_steps
|
||||
while t >= 1:
|
||||
t = t - 1
|
||||
timesteps.append(t)
|
||||
|
||||
if jumps.get(t, 0) > 0:
|
||||
jumps[t] = jumps[t] - 1
|
||||
for _ in range(jump_length):
|
||||
t = t + 1
|
||||
timesteps.append(t)
|
||||
|
||||
timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
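# The schedule walks backwards, but every `jump_length` steps it jumps forward again `jump_n_sample - 1`
# times (re-noising via `undo_step`), producing the back-and-forth resampling schedule of RePaint.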
|
||||
|
||||
def _get_variance(self, t):
|
||||
prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
# For t > 0, compute predicted variance βt (see formula (6) and (7) from
|
||||
# https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
|
||||
# previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
|
||||
# variance to pred_sample
|
||||
# Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
|
||||
# without eta.
|
||||
# variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
|
||||
return variance
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
original_image: torch.FloatTensor,
|
||||
mask: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[RePaintSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned
|
||||
diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
original_image (`torch.FloatTensor`):
|
||||
the original image to inpaint on.
|
||||
mask (`torch.FloatTensor`):
|
||||
the mask where 0.0 values define which part of the original image to inpaint (change).
|
||||
generator (`torch.Generator`, *optional*): random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than
RePaintSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
t = timestep
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
# 1. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
|
||||
# We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we
|
||||
# substitute formula (7) in the algorithm coming from DDPM paper
|
||||
# (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper.
|
||||
# DDIM schedule gives the same results as DDPM with eta = 1.0
|
||||
# Noise is being reused in 7. and 8., but no impact on quality has
|
||||
# been observed.
|
||||
|
||||
# 5. Add noise
|
||||
noise = torch.randn(
|
||||
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
|
||||
)
|
||||
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
|
||||
|
||||
variance = 0
|
||||
if t > 0 and self.eta > 0:
|
||||
variance = std_dev_t * noise
|
||||
|
||||
# 6. compute "direction pointing to x_t" of formula (12)
|
||||
# from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
|
||||
|
||||
# 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
|
||||
|
||||
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
|
||||
prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
|
||||
|
||||
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
|
||||
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
|
||||
|
||||
if not return_dict:
|
||||
return (
|
||||
pred_prev_sample,
|
||||
pred_original_sample,
|
||||
)
|
||||
|
||||
return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def undo_step(self, sample, timestep, generator=None):
|
||||
n = self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
for i in range(n):
|
||||
beta = self.betas[timestep + i]
|
||||
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
|
||||
|
||||
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
|
||||
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
|
||||
|
||||
return sample
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
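# --- Illustrative sketch (not part of this file) ---
# A minimal RePaint-style sampling loop showing how `step` and `undo_step` are meant
# to be combined (Algorithm 1 of https://arxiv.org/pdf/2201.09865.pdf). `unet` is
# assumed to be a noise-prediction model returning an object with a `.sample`
# attribute, and `set_timesteps` (defined elsewhere in this class) is assumed to have
# been called already; the official entry point for this logic is `RePaintPipeline`.
def _repaint_sampling_sketch(scheduler, unet, original_image, mask, generator=None):
    sample = torch.randn(original_image.shape, generator=generator, device=original_image.device)
    t_last = scheduler.timesteps[0] + 1
    for t in scheduler.timesteps:
        if t < t_last:
            # regular reverse step: denoise the unknown region, re-noise the known
            # region, and blend the two with the mask (Algorithm 1, lines 4-8)
            model_output = unet(sample, t).sample
            sample = scheduler.step(model_output, t, sample, original_image, mask, generator).prev_sample
        else:
            # jump back: re-noise x_{t-1} to x_t again (Algorithm 1, line 10)
            sample = scheduler.undo_step(sample, t_last, generator)
        t_last = t
    return sample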
|
||||
@@ -1,494 +0,0 @@
|
||||
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class VQDiffusionSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.LongTensor
|
||||
|
||||
|
||||
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert a batch of class-index vectors into a batch of log one-hot vectors
|
||||
|
||||
Args:
|
||||
x (`torch.LongTensor` of shape `(batch size, vector length)`):
|
||||
Batch of class indices
|
||||
|
||||
num_classes (`int`):
|
||||
number of classes to be used for the onehot vectors
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
|
||||
Log onehot vectors
|
||||
"""
|
||||
x_onehot = F.one_hot(x, num_classes)
|
||||
x_onehot = x_onehot.permute(0, 2, 1)
|
||||
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
|
||||
return log_x
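# Illustrative note (not part of this file): for x = [[2, 0]] and num_classes = 4 the
# result has shape (1, 4, 2); the entry for the "hot" class is log(1) = 0 and every
# other entry is log(1e-30) ≈ -69.08, i.e. a numerically safe log one-hot encoding.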
|
||||
|
||||
|
||||
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
|
||||
"""
|
||||
Apply gumbel noise to `logits`
|
||||
"""
|
||||
uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
|
||||
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
|
||||
noised = gumbel_noise + logits
|
||||
return noised
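# Illustrative note (not part of this file): adding Gumbel noise and then taking the
# argmax over the class dimension draws a sample from softmax(logits) for every latent
# pixel (the Gumbel-max trick); this is exactly how `step` below turns the posterior
# log probabilities into discrete classes:
#     sampled_classes = gumbel_noised(logits, generator).argmax(dim=1)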
|
||||
|
||||
|
||||
def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
|
||||
"""
|
||||
Cumulative and non-cumulative alpha schedules.
|
||||
|
||||
See section 4.1.
|
||||
"""
|
||||
att = (
|
||||
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
|
||||
+ alpha_cum_start
|
||||
)
|
||||
att = np.concatenate(([1], att))
|
||||
at = att[1:] / att[:-1]
|
||||
att = np.concatenate((att[1:], [1]))
|
||||
return at, att
|
||||
|
||||
|
||||
def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
|
||||
"""
|
||||
Cumulative and non-cumulative gamma schedules.
|
||||
|
||||
See section 4.1.
|
||||
"""
|
||||
ctt = (
|
||||
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
|
||||
+ gamma_cum_start
|
||||
)
|
||||
ctt = np.concatenate(([0], ctt))
|
||||
one_minus_ctt = 1 - ctt
|
||||
one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
|
||||
ct = 1 - one_minus_ct
|
||||
ctt = np.concatenate((ctt[1:], [0]))
|
||||
return ct, ctt
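# Illustrative check (not part of this file): at every timestep the transition
# probabilities out of a state must sum to one, i.e. a_t + K * b_t + c_t == 1 with K
# the number of non-masked classes, which is exactly how `b_t` is recovered from the
# two schedules in `VQDiffusionScheduler.__init__` below:
#     at, att = alpha_schedules(100)
#     ct, ctt = gamma_schedules(100)
#     bt = (1 - at - ct) / K  # so that at + K * bt + ct == 1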
|
||||
|
||||
|
||||
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
|
||||
|
||||
The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
|
||||
diffusion timestep.
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2111.14822
|
||||
|
||||
Args:
|
||||
num_vec_classes (`int`):
|
||||
The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
|
||||
latent pixel.
|
||||
|
||||
num_train_timesteps (`int`):
|
||||
Number of diffusion steps used to train the model.
|
||||
|
||||
alpha_cum_start (`float`):
|
||||
The starting cumulative alpha value.
|
||||
|
||||
alpha_cum_end (`float`):
|
||||
The ending cumulative alpha value.
|
||||
|
||||
gamma_cum_start (`float`):
|
||||
The starting cumulative gamma value.
|
||||
|
||||
gamma_cum_end (`float`):
|
||||
The ending cumulative gamma value.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_vec_classes: int,
|
||||
num_train_timesteps: int = 100,
|
||||
alpha_cum_start: float = 0.99999,
|
||||
alpha_cum_end: float = 0.000009,
|
||||
gamma_cum_start: float = 0.000009,
|
||||
gamma_cum_end: float = 0.99999,
|
||||
):
|
||||
self.num_embed = num_vec_classes
|
||||
|
||||
# By convention, the index for the mask class is the last class index
|
||||
self.mask_class = self.num_embed - 1
|
||||
|
||||
at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
|
||||
ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
|
||||
|
||||
num_non_mask_classes = self.num_embed - 1
|
||||
bt = (1 - at - ct) / num_non_mask_classes
|
||||
btt = (1 - att - ctt) / num_non_mask_classes
|
||||
|
||||
at = torch.tensor(at.astype("float64"))
|
||||
bt = torch.tensor(bt.astype("float64"))
|
||||
ct = torch.tensor(ct.astype("float64"))
|
||||
log_at = torch.log(at)
|
||||
log_bt = torch.log(bt)
|
||||
log_ct = torch.log(ct)
|
||||
|
||||
att = torch.tensor(att.astype("float64"))
|
||||
btt = torch.tensor(btt.astype("float64"))
|
||||
ctt = torch.tensor(ctt.astype("float64"))
|
||||
log_cumprod_at = torch.log(att)
|
||||
log_cumprod_bt = torch.log(btt)
|
||||
log_cumprod_ct = torch.log(ctt)
|
||||
|
||||
self.log_at = log_at.float()
|
||||
self.log_bt = log_bt.float()
|
||||
self.log_ct = log_ct.float()
|
||||
self.log_cumprod_at = log_cumprod_at.float()
|
||||
self.log_cumprod_bt = log_cumprod_bt.float()
|
||||
self.log_cumprod_ct = log_cumprod_ct.float()
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
|
||||
device (`str` or `torch.device`):
|
||||
device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.log_at = self.log_at.to(device)
|
||||
self.log_bt = self.log_bt.to(device)
|
||||
self.log_ct = self.log_ct.to(device)
|
||||
self.log_cumprod_at = self.log_cumprod_at.to(device)
|
||||
self.log_cumprod_bt = self.log_cumprod_bt.to(device)
|
||||
self.log_cumprod_ct = self.log_cumprod_ct.to(device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: torch.long,
|
||||
sample: torch.LongTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[VQDiffusionSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
|
||||
docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.
|
||||
|
||||
Args:
    model_output (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
        The log probabilities for the predicted classes of the initial latent pixels. Does not include a
        prediction for the masked class as the initial unnoised image cannot be masked.

    timestep (`torch.long`):
        The timestep that determines which transition matrices are used.

    sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
        The classes of each latent pixel at time `timestep`.

    generator (`torch.Generator`, *optional*):
        RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.

    return_dict (`bool`):
        option for returning a tuple rather than a VQDiffusionSchedulerOutput class.

Returns:
    [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
    [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
    When returning a tuple, the first element is the sample tensor.
"""
|
||||
if timestep == 0:
|
||||
log_p_x_t_min_1 = model_output
|
||||
else:
|
||||
log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
|
||||
|
||||
log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
|
||||
|
||||
x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
|
||||
|
||||
if not return_dict:
|
||||
return (x_t_min_1,)
|
||||
|
||||
return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
|
||||
|
||||
def q_posterior(self, log_p_x_0, x_t, t):
|
||||
"""
|
||||
Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
|
||||
|
||||
Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
|
||||
forward probabilities.
|
||||
|
||||
Equation (11) stated in terms of forward probabilities via Equation (5):

p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )

where the sum is over x_0 = {C_0 ... C_{k-1}} (the possible classes for x_0).
|
||||
|
||||
Args:
|
||||
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
|
||||
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
|
||||
prediction for the masked class as the initial unnoised image cannot be masked.
|
||||
|
||||
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
The classes of each latent pixel at time `t`
|
||||
|
||||
t (torch.Long):
|
||||
The timestep that determines which transition matrix is used.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
|
||||
The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
|
||||
"""
|
||||
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
|
||||
|
||||
log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
|
||||
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
|
||||
)
|
||||
|
||||
log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
|
||||
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
|
||||
)
|
||||
|
||||
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
|
||||
q = log_p_x_0 - log_q_x_t_given_x_0
|
||||
|
||||
# sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
|
||||
# sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
|
||||
q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
|
||||
|
||||
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
|
||||
q = q - q_log_sum_exp
|
||||
|
||||
# (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
|
||||
# c_cumulative_{t-1} ... c_cumulative_{t-1}
|
||||
q = self.apply_cumulative_transitions(q, t - 1)
|
||||
|
||||
# ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
|
||||
# c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0
|
||||
log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
|
||||
|
||||
# For each column, there are two possible cases.
|
||||
#
|
||||
# Where:
|
||||
# - sum(p_n(x_0))) is summing over all classes for x_0
|
||||
# - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
|
||||
# - C_j is the class transitioning to
|
||||
#
|
||||
# 1. x_t is masked i.e. x_t = c_k
|
||||
#
|
||||
# Simplifying the expression, the column vector is:
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
|
||||
#
|
||||
# From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
|
||||
#
|
||||
# For the other rows, we can state the equation as ...
|
||||
#
|
||||
# (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})]
|
||||
#
|
||||
# This verifies the other rows.
|
||||
#
|
||||
# 2. x_t is not masked
|
||||
#
|
||||
# Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# 0
|
||||
#
|
||||
# The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
|
||||
return log_p_x_t_min_1
|
||||
|
||||
def log_Q_t_transitioning_to_known_class(
|
||||
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
|
||||
):
|
||||
"""
|
||||
Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
|
||||
latent pixel in `x_t`.
|
||||
|
||||
See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
|
||||
is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
|
||||
|
||||
Args:
|
||||
t (torch.Long):
|
||||
The timestep that determines which transition matrix is used.
|
||||
|
||||
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
The classes of each latent pixel at time `t`.
|
||||
|
||||
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
|
||||
The log one-hot vectors of `x_t`
|
||||
|
||||
cumulative (`bool`):
|
||||
If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
|
||||
we use the cumulative transition matrix `0`->`t`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
|
||||
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
|
||||
transition matrix.
|
||||
|
||||
When cumulative, the matrix has `self.num_embed - 1` rows, because the initial (unnoised) latent pixel
cannot be masked. When non-cumulative, the row for transitions from the masked class is appended back
and the matrix has `self.num_embed` rows.
|
||||
|
||||
Where:
|
||||
- `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
|
||||
- C_0 is a class of a latent pixel embedding
|
||||
- C_k is the class of the masked latent pixel
|
||||
|
||||
non-cumulative result (omitting logarithms):
|
||||
```
|
||||
q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
|
||||
. . .
|
||||
. . .
|
||||
. . .
|
||||
q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
|
||||
```
|
||||
|
||||
cumulative result (omitting logarithms):
|
||||
```
|
||||
q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
|
||||
. . .
|
||||
. . .
|
||||
. . .
|
||||
q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
|
||||
```
|
||||
"""
|
||||
if cumulative:
|
||||
a = self.log_cumprod_at[t]
|
||||
b = self.log_cumprod_bt[t]
|
||||
c = self.log_cumprod_ct[t]
|
||||
else:
|
||||
a = self.log_at[t]
|
||||
b = self.log_bt[t]
|
||||
c = self.log_ct[t]
|
||||
|
||||
if not cumulative:
|
||||
# The values in the onehot vector can also be used as the logprobs for transitioning
|
||||
# from masked latent pixels. If we are not calculating the cumulative transitions,
|
||||
# we need to save these vectors to be re-appended to the final matrix so the values
|
||||
# aren't overwritten.
|
||||
#
|
||||
# `P(x_t != mask | x_{t-1} = mask) = 0`, and 0 will be the value of the last row of the onehot vector
# if x_t is not masked
#
# `P(x_t = mask | x_{t-1} = mask) = 1`, and 1 will be the value of the last row of the onehot vector
# if x_t is masked
|
||||
log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
|
||||
|
||||
# `index_to_log_onehot` will add onehot vectors for masked pixels,
|
||||
# so the default one hot matrix has one too many rows. See the doc string
|
||||
# for an explanation of the dimensionality of the returned matrix.
|
||||
log_onehot_x_t = log_onehot_x_t[:, :-1, :]
|
||||
|
||||
# this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
|
||||
#
|
||||
# Don't worry about what values this sets in the columns that mark transitions
|
||||
# to masked latent pixels. They are overwritten later with the `mask_class_mask`.
|
||||
#
|
||||
# Looking at the below logspace formula in non-logspace, each value will evaluate to either
|
||||
# `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
|
||||
# or
|
||||
# `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
|
||||
#
|
||||
# See equation 7 for more details.
|
||||
log_Q_t = (log_onehot_x_t + a).logaddexp(b)
|
||||
|
||||
# The whole column of each masked pixel is `c`
|
||||
mask_class_mask = x_t == self.mask_class
|
||||
mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
|
||||
log_Q_t[mask_class_mask] = c
|
||||
|
||||
if not cumulative:
|
||||
log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
|
||||
|
||||
return log_Q_t
|
||||
|
||||
def apply_cumulative_transitions(self, q, t):
|
||||
bsz = q.shape[0]
|
||||
a = self.log_cumprod_at[t]
|
||||
b = self.log_cumprod_bt[t]
|
||||
c = self.log_cumprod_ct[t]
|
||||
|
||||
num_latent_pixels = q.shape[2]
|
||||
c = c.expand(bsz, 1, num_latent_pixels)
|
||||
|
||||
q = (q + a).logaddexp(b)
|
||||
q = torch.cat((q, c), dim=1)
|
||||
|
||||
return q
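# --- Illustrative sketch (not part of this file) ---
# The reverse process this scheduler implements: start from fully masked latents and
# repeatedly sample x_{t-1} with `step`. `transformer` is assumed to return log
# probabilities over the non-masked classes with shape
# (batch, num_vec_classes - 1, num latent pixels); the official entry point is
# `VQDiffusionPipeline`.
def _vq_diffusion_sampling_sketch(scheduler, transformer, batch_size, num_latent_pixels, generator=None):
    # every latent pixel starts out in the masked class
    sample = torch.full((batch_size, num_latent_pixels), scheduler.mask_class, dtype=torch.long)
    scheduler.set_timesteps(scheduler.config.num_train_timesteps)
    for t in scheduler.timesteps:
        log_p_x_0 = transformer(sample, t)  # predicted log p(x_0 | x_t), assumed interface
        sample = scheduler.step(log_p_x_0, t, sample, generator=generator).prev_sample
    return sample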
|
||||
@@ -31,7 +31,6 @@ from .import_utils import (
|
||||
is_scipy_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_unidecode_available,
|
||||
requires_backends,
|
||||
@@ -43,7 +42,6 @@ from .outputs import BaseOutput
|
||||
if is_torch_available():
|
||||
from .testing_utils import (
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
load_image,
|
||||
load_numpy,
|
||||
parse_flag_from_env,
|
||||
|
||||
@@ -94,21 +94,6 @@ class FlaxDDPMScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxKarrasVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
@@ -34,21 +34,6 @@ class AutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -242,21 +227,6 @@ class PNDMPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class RePaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -302,51 +272,6 @@ class DDPMScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DPMSolverMultistepScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class IPNDMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -392,21 +317,6 @@ class PNDMScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class RePaintScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SchedulerMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -437,21 +347,6 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQDiffusionScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class EMAModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -4,21 +4,6 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -92,18 +77,3 @@ class StableDiffusionPipeline(metaclass=DummyObject):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@@ -15,14 +15,11 @@
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
import importlib.util
|
||||
import operator as op
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from packaging import version
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from . import logging
|
||||
|
||||
@@ -43,8 +40,6 @@ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
@@ -95,8 +90,7 @@ else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
if _flax_available:
|
||||
@@ -142,7 +136,6 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_modelcards_available = False
|
||||
|
||||
|
||||
_onnxruntime_version = "N/A"
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
if _onnx_available:
|
||||
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
|
||||
@@ -173,18 +166,6 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_accelerate_available = False
|
||||
|
||||
_xformers_available = importlib.util.find_spec("xformers") is not None
|
||||
try:
|
||||
_xformers_version = importlib_metadata.version("xformers")
|
||||
if _torch_available:
|
||||
import torch
|
||||
|
||||
if version.parse(torch.__version__) < version.Version("1.12"):
|
||||
raise ValueError("PyTorch should be >= 1.12")
|
||||
logger.debug(f"Successfully imported xformers version {_xformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_xformers_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
@@ -222,10 +203,6 @@ def is_scipy_available():
|
||||
return _scipy_available
|
||||
|
||||
|
||||
def is_xformers_available():
|
||||
return _xformers_available
|
||||
|
||||
|
||||
def is_accelerate_available():
|
||||
return _accelerate_available
|
||||
|
||||
@@ -314,36 +291,3 @@ class DummyObject(type):
|
||||
if key.startswith("_"):
|
||||
return super().__getattr__(cls, key)
|
||||
requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
|
||||
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
|
||||
"""
|
||||
Compares a library version to some requirement using a given operation.

Args:
|
||||
library_or_version (`str` or `packaging.version.Version`):
|
||||
A library name or a version to check.
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`.
|
||||
requirement_version (`str`):
|
||||
The version to compare the library version against
|
||||
"""
|
||||
if operation not in STR_OPERATION_TO_FUNC.keys():
|
||||
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
|
||||
operation = STR_OPERATION_TO_FUNC[operation]
|
||||
if isinstance(library_or_version, str):
|
||||
library_or_version = parse(importlib_metadata.version(library_or_version))
|
||||
return operation(library_or_version, parse(requirement_version))
|
||||
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
|
||||
def is_torch_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current PyTorch version to a given reference with an operation.

Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A string version of PyTorch
|
||||
"""
|
||||
return compare_versions(parse(_torch_version), operation, version)
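# Example (illustrative, not part of the diff): gate a code path on the installed
# PyTorch version, e.g.
#     if is_torch_version(">=", "1.12"):
#         ...  # safe to rely on features introduced in torch 1.12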
|
||||
|
||||
@@ -139,29 +139,6 @@ def require_onnxruntime(test_case):
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
|
||||
|
||||
|
||||
def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
|
||||
if isinstance(arry, str):
|
||||
if arry.startswith("http://") or arry.startswith("https://"):
|
||||
response = requests.get(arry)
|
||||
response.raise_for_status()
|
||||
arry = np.load(BytesIO(response.content))
|
||||
elif os.path.isfile(arry):
|
||||
arry = np.load(arry)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
|
||||
)
|
||||
elif isinstance(arry, np.ndarray):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
|
||||
" ndarray."
|
||||
)
|
||||
|
||||
return arry
|
||||
|
||||
|
||||
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
Args:
|
||||
@@ -191,13 +168,17 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def load_hf_numpy(path) -> np.ndarray:
|
||||
def load_numpy(path) -> np.ndarray:
|
||||
if not (path.startswith("http://") or path.startswith("https://")):
|
||||
path = os.path.join(
|
||||
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
|
||||
)
|
||||
|
||||
return load_numpy(path)
|
||||
response = requests.get(path)
|
||||
response.raise_for_status()
|
||||
array = np.load(BytesIO(response.content))
|
||||
|
||||
return array
|
||||
|
||||
|
||||
# --- pytest conf functions --- #
|
||||
|
||||
@@ -28,7 +28,7 @@ class UnetModel1DTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
model_id = "harmonai/maestro-150k"
|
||||
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
|
||||
model = UNet1DModel.from_pretrained(model_id, subfolder="unet", device_map="auto")
|
||||
model.to(torch_device)
|
||||
|
||||
sample_size = 65536
|
||||
|
||||
@@ -21,15 +21,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.utils import (
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
logging,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_numpy, logging, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -125,7 +117,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
@@ -133,8 +127,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_from_pretrained_accelerate_wont_change_results(self):
|
||||
# by default model loading will use accelerate as `low_cpu_mem_usage=True`
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model_accelerate.to(torch_device)
|
||||
model_accelerate.eval()
|
||||
|
||||
@@ -156,7 +151,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
gc.collect()
|
||||
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model_normal_load.to(torch_device)
|
||||
model_normal_load.eval()
|
||||
@@ -170,8 +165,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
gc.collect()
|
||||
|
||||
tracemalloc.start()
|
||||
# by default model loading will use accelerate as `low_cpu_mem_usage=True`
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_accelerate, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
model_accelerate.to(torch_device)
|
||||
model_accelerate.eval()
|
||||
_, peak_accelerate = tracemalloc.get_traced_memory()
|
||||
@@ -180,9 +176,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
|
||||
)
|
||||
model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model_normal_load.to(torch_device)
|
||||
model_normal_load.eval()
|
||||
_, peak_normal = tracemalloc.get_traced_memory()
|
||||
@@ -346,7 +340,9 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||
model, loading_info = UNet2DModel.from_pretrained(
|
||||
"google/ncsnpp-celebahq-256", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
@@ -360,7 +356,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_output_pretrained_ve_mid(self):
|
||||
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
|
||||
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", device_map="auto")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -427,7 +423,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
|
||||
@@ -435,7 +431,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = UNet2DConditionModel.from_pretrained(
|
||||
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
|
||||
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
@@ -443,7 +439,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
hidden_states = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return hidden_states
|
||||
|
||||
@parameterized.expand(
|
||||
@@ -456,7 +452,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
|
||||
latents = self.get_latents(seed)
|
||||
@@ -508,7 +503,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
|
||||
latents = self.get_latents(seed)
|
||||
@@ -560,7 +554,6 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
@require_torch_gpu
|
||||
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
|
||||
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
|
||||
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
@@ -147,7 +147,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
|
||||
return image
|
||||
|
||||
def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
|
||||
@@ -155,10 +155,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
torch_dtype = torch.float16 if fp16 else torch.float32
|
||||
|
||||
model = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
model_id, subfolder="vae", torch_dtype=torch_dtype, revision=revision, device_map="auto"
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||
def test_dance_diffusion(self):
|
||||
device = torch_device
|
||||
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -103,7 +103,9 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||
def test_dance_diffusion_fp16(self):
|
||||
device = torch_device
|
||||
|
||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
|
||||
pipe = DanceDiffusionPipeline.from_pretrained(
|
||||
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_ema_bedroom(self):
|
||||
model_id = "google/ddpm-ema-bedroom-256"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
@@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
scheduler = DDIMScheduler()
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
@@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||
scheduler = KarrasVeScheduler()
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
|
||||
@@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_text2img(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_text2img_fast(self):
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|