mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-24 13:24:49 +08:00
Compare commits
34 Commits
improve_co
...
v0.11.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7bb9cae11 | ||
|
|
40b0519a8a | ||
|
|
a5edb981a7 | ||
|
|
54796b7e43 | ||
|
|
4cb887e0a7 | ||
|
|
9f657f106d | ||
|
|
ce1c27adc8 | ||
|
|
b267d28566 | ||
|
|
c7b4acfb37 | ||
|
|
be38b2d711 | ||
|
|
32a5d70c42 | ||
|
|
429e5449c1 | ||
|
|
dc7cd893fd | ||
|
|
8890758823 | ||
|
|
b25843e799 | ||
|
|
830a9d1f01 | ||
|
|
2dcf64b72a | ||
|
|
402b9560b2 | ||
|
|
c2a38ef9df | ||
|
|
08cc36ddff | ||
|
|
723e8f6bb4 | ||
|
|
c53a850604 | ||
|
|
086c7f9ea8 | ||
|
|
acd317810b | ||
|
|
c6d0dff4a3 | ||
|
|
a40095dd22 | ||
|
|
727434c206 | ||
|
|
21e61eb3a9 | ||
|
|
c891330f79 | ||
|
|
c5f04d4e34 | ||
|
|
61dec53356 | ||
|
|
badddee0ef | ||
|
|
13994b2d3f | ||
|
|
ea90bf2ba1 |
106
.github/workflows/nightly_tests.yml
vendored
106
.github/workflows/nightly_tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Nightly integration tests
|
||||
name: Nightly tests on main
|
||||
|
||||
on:
|
||||
schedule:
|
||||
@@ -9,12 +9,107 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 1000
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: yes
|
||||
RUN_NIGHTLY: yes
|
||||
|
||||
jobs:
|
||||
run_slow_tests_apple_m1:
|
||||
name: Slow PyTorch MPS tests on MacOS
|
||||
run_nightly_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Nightly PyTorch CUDA tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
report: torch_cuda
|
||||
- name: Nightly Flax TPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Nightly ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if: ${{ matrix.config.runner == 'docker-gpu' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run nightly PyTorch CUDA tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run nightly ONNXRuntime CUDA tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_nightly_tests_apple_m1:
|
||||
name: Nightly PyTorch MPS tests on MacOS
|
||||
runs-on: [ self-hosted, apple-m1 ]
|
||||
|
||||
steps:
|
||||
@@ -39,14 +134,13 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python utils/print_env.py
|
||||
|
||||
- name: Run slow PyTorch tests on M1 (MPS)
|
||||
- name: Run nightly PyTorch tests on M1 (MPS)
|
||||
shell: arch -arch arm64 bash {0}
|
||||
env:
|
||||
HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
|
||||
4
.github/workflows/pr_tests.yml
vendored
4
.github/workflows/pr_tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Run fast tests
|
||||
name: Fast tests for PRs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
@@ -59,7 +59,6 @@ jobs:
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
@@ -127,7 +126,6 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
|
||||
6
.github/workflows/push_tests.yml
vendored
6
.github/workflows/push_tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Run all tests
|
||||
name: Slow tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -10,7 +10,7 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 1000
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: yes
|
||||
|
||||
jobs:
|
||||
@@ -61,7 +61,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
@@ -131,7 +130,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -166,3 +166,6 @@ tags
|
||||
.DS_Store
|
||||
# RL pipelines may produce mp4 outputs
|
||||
*.mp4
|
||||
|
||||
# dependencies
|
||||
/transformers
|
||||
|
||||
@@ -302,11 +302,8 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
|
||||
### Tweak prompts reusing seeds and latents
|
||||
|
||||
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
|
||||
|
||||
|
||||
For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
|
||||
and have a look into the [release notes](https://github.com/huggingface/diffusers/releases/tag/v0.2.0).
|
||||
You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked.
|
||||
Please have a look at [Reusing seeds for deterministic generation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/reusing_seeds).
|
||||
|
||||
## Fine-Tuning Stable Diffusion
|
||||
|
||||
|
||||
266
docs/README.md
Normal file
266
docs/README.md
Normal file
@@ -0,0 +1,266 @@
|
||||
<!---
|
||||
Copyright 2022- The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Generating the documentation
|
||||
|
||||
To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
|
||||
you can install them with the following command, at the root of the code repository:
|
||||
|
||||
```bash
|
||||
pip install -e ".[docs]"
|
||||
```
|
||||
|
||||
Then you need to install our open source documentation builder tool:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/doc-builder
|
||||
```
|
||||
|
||||
---
|
||||
**NOTE**
|
||||
|
||||
You only need to generate the documentation to inspect it locally (if you're planning changes and want to
|
||||
check how they look before committing for instance). You don't have to commit the built documentation.
|
||||
|
||||
---
|
||||
|
||||
## Previewing the documentation
|
||||
|
||||
To preview the docs, first install the `watchdog` module with:
|
||||
|
||||
```bash
|
||||
pip install watchdog
|
||||
```
|
||||
|
||||
Then run the following command:
|
||||
|
||||
```bash
|
||||
doc-builder preview {package_name} {path_to_docs}
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```bash
|
||||
doc-builder preview diffusers docs/source/
|
||||
```
|
||||
|
||||
The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
|
||||
|
||||
---
|
||||
**NOTE**
|
||||
|
||||
The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
|
||||
|
||||
---
|
||||
|
||||
## Adding a new element to the navigation bar
|
||||
|
||||
Accepted files are Markdown (.md or .mdx).
|
||||
|
||||
Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
|
||||
the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/diffusers/blob/main/docs/source/_toctree.yml) file.
|
||||
|
||||
## Renaming section headers and moving sections
|
||||
|
||||
It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
|
||||
|
||||
Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
|
||||
|
||||
So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
|
||||
|
||||
```
|
||||
Sections that were moved:
|
||||
|
||||
[ <a href="#section-b">Section A</a><a id="section-a"></a> ]
|
||||
```
|
||||
and of course, if you moved it to another file, then:
|
||||
|
||||
```
|
||||
Sections that were moved:
|
||||
|
||||
[ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
|
||||
```
|
||||
|
||||
Use the relative style to link to the new file so that the versioned docs continue to work.
|
||||
|
||||
For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.mdx).
|
||||
|
||||
|
||||
## Writing Documentation - Specification
|
||||
|
||||
The `huggingface/diffusers` documentation follows the
|
||||
[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
|
||||
although we can write them directly in Markdown.
|
||||
|
||||
### Adding a new tutorial
|
||||
|
||||
Adding a new tutorial or section is done in two steps:
|
||||
|
||||
- Add a new file under `docs/source`. This file can either be ReStructuredText (.rst) or Markdown (.md).
|
||||
- Link that file in `docs/source/_toctree.yml` on the correct toc-tree.
|
||||
|
||||
Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so
|
||||
depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or four.
|
||||
|
||||
### Adding a new pipeline/scheduler
|
||||
|
||||
When adding a new pipeline:
|
||||
|
||||
- create a file `xxx.mdx` under `docs/source/api/pipelines` (don't hesitate to copy an existing file as template).
|
||||
- Link that file in (*Diffusers Summary*) section in `docs/source/api/pipelines/overview.mdx`, along with the link to the paper, and a colab notebook (if available).
|
||||
- Write a short overview of the diffusion model:
|
||||
- Overview with paper & authors
|
||||
- Paper abstract
|
||||
- Tips and tricks and how to use it best
|
||||
- Possible an end-to-end example of how to use it
|
||||
- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. Usually as follows:
|
||||
|
||||
```
|
||||
## XXXPipeline
|
||||
|
||||
[[autodoc]] XXXPipeline
|
||||
```
|
||||
|
||||
This will include every public method of the pipeline that is documented. You can specify which methods should be in the docs:
|
||||
|
||||
```
|
||||
## XXXPipeline
|
||||
|
||||
[[autodoc]] XXXPipeline
|
||||
- __call__
|
||||
```
|
||||
|
||||
You can follow the same process to create a new scheduler under the `docs/source/api/schedulers` folder
|
||||
|
||||
### Writing source documentation
|
||||
|
||||
Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
|
||||
and objects like True, None, or any strings should usually be put in `code`.
|
||||
|
||||
When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
|
||||
adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
|
||||
function to be in the main package.
|
||||
|
||||
If you want to create a link to some internal class or function, you need to
|
||||
provide its path. For instance: \[\`pipeline_utils.ImagePipelineOutput\`\]. This will be converted into a link with
|
||||
`pipeline_utils.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
|
||||
linking to in the description, add a ~: \[\`~pipeline_utils.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description.
|
||||
|
||||
The same works for methods so you can either use \[\`XXXClass.method\`\] or \[~\`XXXClass.method\`\].
|
||||
|
||||
#### Defining arguments in a method
|
||||
|
||||
Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
|
||||
an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
|
||||
description:
|
||||
|
||||
```
|
||||
Args:
|
||||
n_layers (`int`): The number of layers of the model.
|
||||
```
|
||||
|
||||
If the description is too long to fit in one line, another indentation is necessary before writing the description
|
||||
after the argument.
|
||||
|
||||
Here's an example showcasing everything so far:
|
||||
|
||||
```
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and
|
||||
[`~PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
```
|
||||
|
||||
For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
|
||||
following signature:
|
||||
|
||||
```
|
||||
def my_function(x: str = None, a: float = 1):
|
||||
```
|
||||
|
||||
then its documentation should look like this:
|
||||
|
||||
```
|
||||
Args:
|
||||
x (`str`, *optional*):
|
||||
This argument controls ...
|
||||
a (`float`, *optional*, defaults to 1):
|
||||
This argument is used to ...
|
||||
```
|
||||
|
||||
Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
|
||||
if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
|
||||
however write as many lines as you want in the indented description (see the example above with `input_ids`).
|
||||
|
||||
#### Writing a multi-line code block
|
||||
|
||||
Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
|
||||
|
||||
|
||||
````
|
||||
```
|
||||
# first line of code
|
||||
# second line
|
||||
# etc
|
||||
```
|
||||
````
|
||||
|
||||
#### Writing a return block
|
||||
|
||||
The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
|
||||
The first line should be the type of the return, followed by a line return. No need to indent further for the elements
|
||||
building the return.
|
||||
|
||||
Here's an example of a single value return:
|
||||
|
||||
```
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
|
||||
```
|
||||
|
||||
Here's an example of a tuple return, comprising several objects:
|
||||
|
||||
```
|
||||
Returns:
|
||||
`tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
|
||||
- ** loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
|
||||
Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
||||
- **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
```
|
||||
|
||||
#### Adding an image
|
||||
|
||||
Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
|
||||
the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
|
||||
them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
|
||||
If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
|
||||
to this dataset.
|
||||
|
||||
## Styling the docstring
|
||||
|
||||
We have an automatic script running with the `make style` command that will make sure that:
|
||||
- the docstrings fully take advantage of the line width
|
||||
- all code examples are formatted using black, like the code of the Transformers library
|
||||
|
||||
This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's
|
||||
recommended to commit your changes before running `make style`, so you can revert the changes done by that script
|
||||
easily.
|
||||
|
||||
@@ -28,6 +28,8 @@
|
||||
title: "Text-Guided Image-Inpainting"
|
||||
- local: using-diffusers/depth2img
|
||||
title: "Text-Guided Depth-to-Image"
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: "Reusing seeds for deterministic generation"
|
||||
- local: using-diffusers/custom_pipeline_examples
|
||||
title: "Community Pipelines"
|
||||
- local: using-diffusers/contribute_pipeline
|
||||
@@ -45,6 +47,8 @@
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
title: "Memory and Speed"
|
||||
- local: optimization/xformers
|
||||
title: "xFormers"
|
||||
- local: optimization/onnx
|
||||
title: "ONNX"
|
||||
- local: optimization/open_vino
|
||||
@@ -78,8 +82,6 @@
|
||||
- sections:
|
||||
- local: api/models
|
||||
title: "Models"
|
||||
- local: api/schedulers
|
||||
title: "Schedulers"
|
||||
- local: api/diffusion_pipeline
|
||||
title: "Diffusion Pipeline"
|
||||
- local: api/logging
|
||||
@@ -120,6 +122,8 @@
|
||||
title: "Stochastic Karras VE"
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: "Dance Diffusion"
|
||||
- local: api/pipelines/unclip
|
||||
title: "UnCLIP"
|
||||
- local: api/pipelines/versatile_diffusion
|
||||
title: "Versatile Diffusion"
|
||||
- local: api/pipelines/vq_diffusion
|
||||
@@ -129,6 +133,44 @@
|
||||
- local: api/pipelines/audio_diffusion
|
||||
title: "Audio Diffusion"
|
||||
title: "Pipelines"
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: "Overview"
|
||||
- local: api/schedulers/ddim
|
||||
title: "DDIM"
|
||||
- local: api/schedulers/ddpm
|
||||
title: "DDPM"
|
||||
- local: api/schedulers/singlestep_dpm_solver
|
||||
title: "Singlestep DPM-Solver"
|
||||
- local: api/schedulers/multistep_dpm_solver
|
||||
title: "Multistep DPM-Solver"
|
||||
- local: api/schedulers/heun
|
||||
title: "Heun Scheduler"
|
||||
- local: api/schedulers/dpm_discrete
|
||||
title: "DPM Discrete Scheduler"
|
||||
- local: api/schedulers/dpm_discrete_ancestral
|
||||
title: "DPM Discrete Scheduler with ancestral sampling"
|
||||
- local: api/schedulers/stochastic_karras_ve
|
||||
title: "Stochastic Kerras VE"
|
||||
- local: api/schedulers/lms_discrete
|
||||
title: "Linear Multistep"
|
||||
- local: api/schedulers/pndm
|
||||
title: "PNDM"
|
||||
- local: api/schedulers/score_sde_ve
|
||||
title: "VE-SDE"
|
||||
- local: api/schedulers/ipndm
|
||||
title: "IPNDM"
|
||||
- local: api/schedulers/score_sde_vp
|
||||
title: "VP-SDE"
|
||||
- local: api/schedulers/euler
|
||||
title: "Euler scheduler"
|
||||
- local: api/schedulers/euler_ancestral
|
||||
title: "Euler Ancestral Scheduler"
|
||||
- local: api/schedulers/vq_diffusion
|
||||
title: "VQDiffusionScheduler"
|
||||
- local: api/schedulers/repaint
|
||||
title: "RePaint Scheduler"
|
||||
title: "Schedulers"
|
||||
- sections:
|
||||
- local: api/experimental/rl
|
||||
title: "RL Planning"
|
||||
|
||||
@@ -58,6 +58,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
||||
## Transformer2DModelOutput
|
||||
[[autodoc]] models.attention.Transformer2DModelOutput
|
||||
|
||||
## PriorTransformer
|
||||
[[autodoc]] models.prior_transformer.PriorTransformer
|
||||
|
||||
## PriorTransformerOutput
|
||||
[[autodoc]] models.prior_transformer.PriorTransformerOutput
|
||||
|
||||
## FlaxModelMixin
|
||||
[[autodoc]] FlaxModelMixin
|
||||
|
||||
|
||||
@@ -44,31 +44,32 @@ available a colab notebook to directly try them out.
|
||||
|
||||
| Pipeline | Paper | Tasks | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
|
||||
| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
|
||||
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
|
||||
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
| [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
|
||||
| [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
|
||||
| [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
|
||||
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [paint_by_example](./paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
|
||||
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
|
||||
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
@@ -138,9 +139,9 @@ from diffusers import StableDiffusionImg2ImgPipeline
|
||||
|
||||
# load the pipeline
|
||||
device = "cuda"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
|
||||
).to(device)
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
|
||||
device
|
||||
)
|
||||
|
||||
# let's download an initial image
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
@@ -188,7 +189,6 @@ mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -113,7 +113,7 @@ import torch
|
||||
|
||||
# load model and scheduler
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
# let's download an image
|
||||
|
||||
31
docs/source/api/pipelines/unclip.mdx
Normal file
31
docs/source/api/pipelines/unclip.mdx
Normal file
@@ -0,0 +1,31 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# unCLIP
|
||||
|
||||
## Overview
|
||||
|
||||
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.
|
||||
|
||||
The unCLIP model in diffusers comes from kakaobrain's karlo and the original codebase can be found [here](https://github.com/kakaobrain/karlo). Additionally, lucidrains has a DALL-E 2 recreation [here](https://github.com/lucidrains/DALLE2-pytorch).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|---|---|:---:|
|
||||
| [pipeline_unclip.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/unclip/pipeline_unclip.py) | *Text-to-Image Generation* | - |
|
||||
|
||||
|
||||
## UnCLIPPipeline
|
||||
[[autodoc]] pipelines.unclip.pipeline_unclip.UnCLIPPipeline
|
||||
- __call__
|
||||
@@ -1,183 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Schedulers
|
||||
|
||||
Diffusers contains multiple pre-built schedule functions for the diffusion process.
|
||||
|
||||
## What is a scheduler?
|
||||
|
||||
The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. That's why schedulers may also be called *Samplers* in other diffusion models implementations.
|
||||
|
||||
- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
|
||||
- adding noise in different manners represent the algorithmic processes to train a diffusion model by adding noise to images.
|
||||
- for inference, the scheduler defines how to update a sample based on an output from a pretrained model.
|
||||
- Schedulers are often defined by a *noise schedule* and an *update rule* to solve the differential equation solution.
|
||||
|
||||
### Discrete versus continuous schedulers
|
||||
|
||||
All schedulers take in a timestep to predict the updated version of the sample being diffused.
|
||||
The timesteps dictate where in the diffusion process the step is, where data is generated by iterating forward in time and inference is executed by propagating backwards through timesteps.
|
||||
Different algorithms use timesteps that both discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], and continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
|
||||
|
||||
## Designing Re-usable schedulers
|
||||
|
||||
The core design principle between the schedule functions is to be model, system, and framework independent.
|
||||
This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
|
||||
To this end, the design of schedulers is such that:
|
||||
|
||||
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
|
||||
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
|
||||
|
||||
|
||||
## API
|
||||
|
||||
The core API for any new scheduler must follow a limited structure.
|
||||
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
|
||||
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
|
||||
- Schedulers should be framework-specific.
|
||||
|
||||
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
|
||||
|
||||
### SchedulerMixin
|
||||
[[autodoc]] SchedulerMixin
|
||||
|
||||
### SchedulerOutput
|
||||
The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call.
|
||||
|
||||
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
|
||||
|
||||
### Implemented Schedulers
|
||||
|
||||
#### Denoising diffusion implicit models (DDIM)
|
||||
|
||||
Original paper can be found here.
|
||||
|
||||
[[autodoc]] DDIMScheduler
|
||||
|
||||
#### Denoising diffusion probabilistic models (DDPM)
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2010.02502).
|
||||
|
||||
[[autodoc]] DDPMScheduler
|
||||
|
||||
#### Singlestep DPM-Solver
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
[[autodoc]] DPMSolverSinglestepScheduler
|
||||
|
||||
#### Multistep DPM-Solver
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
[[autodoc]] DPMSolverMultistepScheduler
|
||||
|
||||
#### Heun scheduler inspired by Karras et. al paper
|
||||
|
||||
Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] HeunDiscreteScheduler
|
||||
|
||||
#### DPM Discrete Scheduler inspired by Karras et. al paper
|
||||
|
||||
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] KDPM2DiscreteScheduler
|
||||
|
||||
#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper
|
||||
|
||||
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
[[autodoc]] KDPM2AncestralDiscreteScheduler
|
||||
|
||||
#### Variance exploding, stochastic sampling from Karras et. al
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
|
||||
|
||||
[[autodoc]] KarrasVeScheduler
|
||||
|
||||
#### Linear multistep scheduler for discrete beta schedules
|
||||
|
||||
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
[[autodoc]] LMSDiscreteScheduler
|
||||
|
||||
#### Pseudo numerical methods for diffusion models (PNDM)
|
||||
|
||||
Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
|
||||
|
||||
[[autodoc]] PNDMScheduler
|
||||
|
||||
#### variance exploding stochastic differential equation (VE-SDE) scheduler
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
|
||||
|
||||
[[autodoc]] ScoreSdeVeScheduler
|
||||
|
||||
#### improved pseudo numerical methods for diffusion models (iPNDM)
|
||||
|
||||
Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).
|
||||
|
||||
[[autodoc]] IPNDMScheduler
|
||||
|
||||
#### variance preserving stochastic differential equation (VP-SDE) scheduler
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Score SDE-VP is under construction.
|
||||
|
||||
</Tip>
|
||||
|
||||
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
|
||||
#### Euler scheduler
|
||||
|
||||
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerDiscreteScheduler
|
||||
|
||||
|
||||
#### Euler Ancestral scheduler
|
||||
|
||||
Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerAncestralDiscreteScheduler
|
||||
|
||||
|
||||
#### VQDiffusionScheduler
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
|
||||
|
||||
[[autodoc]] VQDiffusionScheduler
|
||||
|
||||
#### RePaint scheduler
|
||||
|
||||
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
|
||||
Intended for use with [`RePaintPipeline`].
|
||||
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
|
||||
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
|
||||
|
||||
[[autodoc]] RePaintScheduler
|
||||
27
docs/source/api/schedulers/ddim.mdx
Normal file
27
docs/source/api/schedulers/ddim.mdx
Normal file
@@ -0,0 +1,27 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Denoising diffusion implicit models (DDIM)
|
||||
|
||||
## Overview
|
||||
|
||||
[Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
|
||||
|
||||
The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim).
|
||||
For questions, feel free to contact the author on [tsong.me](https://tsong.me/).
|
||||
|
||||
## DDIMScheduler
|
||||
[[autodoc]] DDIMScheduler
|
||||
27
docs/source/api/schedulers/ddpm.mdx
Normal file
27
docs/source/api/schedulers/ddpm.mdx
Normal file
@@ -0,0 +1,27 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Denoising diffusion probabilistic models (DDPM)
|
||||
|
||||
## Overview
|
||||
|
||||
[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||
(DDPM) by Jonathan Ho, Ajay Jain and Pieter Abbeel proposes the diffusion based model of the same name, but in the context of the 🤗 Diffusers library, DDPM refers to the discrete denoising scheduler from the paper as well as the pipeline.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN.
|
||||
|
||||
The original paper can be found [here](https://arxiv.org/abs/2010.02502).
|
||||
|
||||
## DDPMScheduler
|
||||
[[autodoc]] DDPMScheduler
|
||||
22
docs/source/api/schedulers/dpm_discrete.mdx
Normal file
22
docs/source/api/schedulers/dpm_discrete.mdx
Normal file
@@ -0,0 +1,22 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DPM Discrete Scheduler inspired by Karras et. al paper
|
||||
|
||||
## Overview
|
||||
|
||||
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364). Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
## KDPM2DiscreteScheduler
|
||||
[[autodoc]] KDPM2DiscreteScheduler
|
||||
22
docs/source/api/schedulers/dpm_discrete_ancestral.mdx
Normal file
22
docs/source/api/schedulers/dpm_discrete_ancestral.mdx
Normal file
@@ -0,0 +1,22 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper
|
||||
|
||||
## Overview
|
||||
|
||||
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364). Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
## KDPM2AncestralDiscreteScheduler
|
||||
[[autodoc]] KDPM2AncestralDiscreteScheduler
|
||||
21
docs/source/api/schedulers/euler.mdx
Normal file
21
docs/source/api/schedulers/euler.mdx
Normal file
@@ -0,0 +1,21 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Euler scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
## EulerDiscreteScheduler
|
||||
[[autodoc]] EulerDiscreteScheduler
|
||||
21
docs/source/api/schedulers/euler_ancestral.mdx
Normal file
21
docs/source/api/schedulers/euler_ancestral.mdx
Normal file
@@ -0,0 +1,21 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Euler Ancestral scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
## EulerAncestralDiscreteScheduler
|
||||
[[autodoc]] EulerAncestralDiscreteScheduler
|
||||
23
docs/source/api/schedulers/heun.mdx
Normal file
23
docs/source/api/schedulers/heun.mdx
Normal file
@@ -0,0 +1,23 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Heun scheduler inspired by Karras et. al paper
|
||||
|
||||
## Overview
|
||||
|
||||
Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364).
|
||||
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
|
||||
|
||||
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
|
||||
|
||||
## HeunDiscreteScheduler
|
||||
[[autodoc]] HeunDiscreteScheduler
|
||||
20
docs/source/api/schedulers/ipndm.mdx
Normal file
20
docs/source/api/schedulers/ipndm.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# improved pseudo numerical methods for diffusion models (iPNDM)
|
||||
|
||||
## Overview
|
||||
|
||||
Original implementation can be found [here](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).
|
||||
|
||||
## IPNDMScheduler
|
||||
[[autodoc]] IPNDMScheduler
|
||||
20
docs/source/api/schedulers/lms_discrete.mdx
Normal file
20
docs/source/api/schedulers/lms_discrete.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Linear multistep scheduler for discrete beta schedules
|
||||
|
||||
## Overview
|
||||
|
||||
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
## LMSDiscreteScheduler
|
||||
[[autodoc]] LMSDiscreteScheduler
|
||||
20
docs/source/api/schedulers/multistep_dpm_solver.mdx
Normal file
20
docs/source/api/schedulers/multistep_dpm_solver.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Multistep DPM-Solver
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
## DPMSolverMultistepScheduler
|
||||
[[autodoc]] DPMSolverMultistepScheduler
|
||||
83
docs/source/api/schedulers/overview.mdx
Normal file
83
docs/source/api/schedulers/overview.mdx
Normal file
@@ -0,0 +1,83 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Schedulers
|
||||
|
||||
Diffusers contains multiple pre-built schedule functions for the diffusion process.
|
||||
|
||||
## What is a scheduler?
|
||||
|
||||
The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. That's why schedulers may also be called *Samplers* in other diffusion models implementations.
|
||||
|
||||
- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
|
||||
- adding noise in different manners represent the algorithmic processes to train a diffusion model by adding noise to images.
|
||||
- for inference, the scheduler defines how to update a sample based on an output from a pretrained model.
|
||||
- Schedulers are often defined by a *noise schedule* and an *update rule* to solve the differential equation solution.
|
||||
|
||||
### Discrete versus continuous schedulers
|
||||
|
||||
All schedulers take in a timestep to predict the updated version of the sample being diffused.
|
||||
The timesteps dictate where in the diffusion process the step is, where data is generated by iterating forward in time and inference is executed by propagating backwards through timesteps.
|
||||
Different algorithms use timesteps that can be discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], or continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
|
||||
|
||||
## Designing Re-usable schedulers
|
||||
|
||||
The core design principle between the schedule functions is to be model, system, and framework independent.
|
||||
This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
|
||||
To this end, the design of schedulers is such that:
|
||||
|
||||
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
|
||||
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
|
||||
|
||||
## Schedulers Summary
|
||||
|
||||
The following table summarizes all officially supported schedulers, their corresponding paper
|
||||
|
||||
|
||||
| Scheduler | Paper |
|
||||
|---|---|
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
|
||||
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
|
||||
| [multistep_dpm_solver](./multistep_dpm_solver) | [**Multistep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
|
||||
| [heun](./heun) | [**Heun scheduler inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) |
|
||||
| [dpm_discrete](./dpm_discrete) | [**DPM Discrete Scheduler inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) |
|
||||
| [dpm_discrete_ancestral](./dpm_discrete_ancestral) | [**DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper**](https://arxiv.org/abs/2206.00364) |
|
||||
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Variance exploding, stochastic sampling from Karras et. al**](https://arxiv.org/abs/2206.00364) |
|
||||
| [lms_discrete](./lms_discrete) | [**Linear multistep scheduler for discrete beta schedules**](https://arxiv.org/abs/2206.00364) |
|
||||
| [pndm](./pndm) | [**Pseudo numerical methods for diffusion models (PNDM)**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181) |
|
||||
| [score_sde_ve](./score_sde_ve) | [**variance exploding stochastic differential equation (VE-SDE) scheduler**](https://arxiv.org/abs/2011.13456) |
|
||||
| [ipndm](./ipndm) | [**improved pseudo numerical methods for diffusion models (iPNDM)**](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) |
|
||||
| [score_sde_vp](./score_sde_vp) | [**Variance preserving stochastic differential equation (VP-SDE) scheduler**](https://arxiv.org/abs/2011.13456) |
|
||||
| [euler](./euler) | [**Euler scheduler**](https://arxiv.org/abs/2206.00364) |
|
||||
| [euler_ancestral](./euler_ancestral) | [**Euler Ancestral scheduler**](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) |
|
||||
| [vq_diffusion](./vq_diffusion) | [**VQDiffusionScheduler**](https://arxiv.org/abs/2111.14822) |
|
||||
| [repaint](./repaint) | [**RePaint scheduler**](https://arxiv.org/abs/2201.09865) |
|
||||
|
||||
## API
|
||||
|
||||
The core API for any new scheduler must follow a limited structure.
|
||||
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
|
||||
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
|
||||
- Schedulers should be framework-specific.
|
||||
|
||||
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
|
||||
|
||||
### SchedulerMixin
|
||||
[[autodoc]] SchedulerMixin
|
||||
|
||||
### SchedulerOutput
|
||||
The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call.
|
||||
|
||||
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
|
||||
|
||||
|
||||
20
docs/source/api/schedulers/pndm.mdx
Normal file
20
docs/source/api/schedulers/pndm.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Pseudo numerical methods for diffusion models (PNDM)
|
||||
|
||||
## Overview
|
||||
|
||||
Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
|
||||
|
||||
## PNDMScheduler
|
||||
[[autodoc]] PNDMScheduler
|
||||
23
docs/source/api/schedulers/repaint.mdx
Normal file
23
docs/source/api/schedulers/repaint.mdx
Normal file
@@ -0,0 +1,23 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# RePaint scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
|
||||
Intended for use with [`RePaintPipeline`].
|
||||
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
|
||||
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
|
||||
|
||||
## RePaintScheduler
|
||||
[[autodoc]] RePaintScheduler
|
||||
20
docs/source/api/schedulers/score_sde_ve.mdx
Normal file
20
docs/source/api/schedulers/score_sde_ve.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# variance exploding stochastic differential equation (VE-SDE) scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
|
||||
|
||||
## ScoreSdeVeScheduler
|
||||
[[autodoc]] ScoreSdeVeScheduler
|
||||
26
docs/source/api/schedulers/score_sde_vp.mdx
Normal file
26
docs/source/api/schedulers/score_sde_vp.mdx
Normal file
@@ -0,0 +1,26 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Variance preserving stochastic differential equation (VP-SDE) scheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Score SDE-VP is under construction.
|
||||
|
||||
</Tip>
|
||||
|
||||
## ScoreSdeVpScheduler
|
||||
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
20
docs/source/api/schedulers/singlestep_dpm_solver.mdx
Normal file
20
docs/source/api/schedulers/singlestep_dpm_solver.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Singlestep DPM-Solver
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
## DPMSolverSinglestepScheduler
|
||||
[[autodoc]] DPMSolverSinglestepScheduler
|
||||
20
docs/source/api/schedulers/stochastic_karras_ve.mdx
Normal file
20
docs/source/api/schedulers/stochastic_karras_ve.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Variance exploding, stochastic sampling from Karras et. al
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
## KarrasVeScheduler
|
||||
[[autodoc]] KarrasVeScheduler
|
||||
20
docs/source/api/schedulers/vq_diffusion.mdx
Normal file
20
docs/source/api/schedulers/vq_diffusion.mdx
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# VQDiffusionScheduler
|
||||
|
||||
## Overview
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
|
||||
|
||||
## VQDiffusionScheduler
|
||||
[[autodoc]] VQDiffusionScheduler
|
||||
@@ -23,7 +23,7 @@ specific language governing permissions and limitations under the License.
|
||||
More precisely, 🤗 Diffusers offers:
|
||||
|
||||
- State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [**Using Diffusers**](./using-diffusers/conditional_image_generation)) or have a look at [**Pipelines**](#pipelines) to get an overview of all supported pipelines and their corresponding papers.
|
||||
- Various noise schedulers that can be used interchangeably for the preferred speed vs. quality trade-off in inference. For more information see [**Schedulers**](./api/schedulers).
|
||||
- Various noise schedulers that can be used interchangeably for the preferred speed vs. quality trade-off in inference. For more information see [**Schedulers**](./api/schedulers/overview).
|
||||
- Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system. See [**Models**](./api/models) for more details
|
||||
- Training examples to show how to train the most popular diffusion model tasks. For more information see [**Training**](./training/overview).
|
||||
|
||||
@@ -55,6 +55,7 @@ available a colab notebook to directly try them out.
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
|
||||
@@ -12,7 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Memory and speed
|
||||
|
||||
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
|
||||
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed. As a general rule, we recommend the use of [xFormers](https://github.com/facebookresearch/xformers) for memory efficient attention, please see the recommended [installation instructions](xformers).
|
||||
|
||||
We'll discuss how the following settings impact performance and memory.
|
||||
|
||||
| | Latency | Speedup |
|
||||
| ---------------- | ------- | ------- |
|
||||
@@ -77,7 +79,7 @@ To save more GPU memory and get even more speed, you can load and run the model
|
||||
```Python
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -105,7 +107,7 @@ from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -132,7 +134,7 @@ from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -157,7 +159,7 @@ from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -177,7 +179,7 @@ from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -232,7 +234,6 @@ def generate_inputs():
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
unet = pipe.unet
|
||||
@@ -296,7 +297,6 @@ class UNet2DConditionOutput:
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
@@ -322,7 +322,9 @@ with torch.inference_mode():
|
||||
|
||||
|
||||
## Memory Efficient Attention
|
||||
Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) .
|
||||
|
||||
Recent work on optimizing the bandwitdh in the attention block has generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf).
|
||||
|
||||
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):
|
||||
|
||||
| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|
||||
@@ -338,14 +340,13 @@ Here are the speedups we obtain on a few Nvidia GPUs when running the inference
|
||||
To leverage it just make sure you have:
|
||||
- PyTorch > 1.12
|
||||
- Cuda available
|
||||
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
|
||||
- [Installed the xformers library](xformers).
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
|
||||
26
docs/source/optimization/xformers.mdx
Normal file
26
docs/source/optimization/xformers.mdx
Normal file
@@ -0,0 +1,26 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Installing xFormers
|
||||
|
||||
We recommend the use of [xFormers](https://github.com/facebookresearch/xformers) for both inference and training. In our tests, the optimizations performed in the attention blocks allow for both faster speed and reduced memory consumption.
|
||||
|
||||
Installing xFormers has historically been a bit involved, as binary distributions were not always up to date. Fortunately, the project has [very recently](https://github.com/facebookresearch/xformers/pull/591) integrated a process to build pip wheels as part of the project's continuous integration, so this should improve a lot starting from xFormers version 0.0.16.
|
||||
|
||||
Until xFormers 0.0.16 is deployed, you can install pip wheels using [`TestPyPI`](https://test.pypi.org/project/formers/). These are the steps that worked for us in a Linux computer to install xFormers version 0.0.15:
|
||||
|
||||
```bash
|
||||
pip install pyre-extensions==0.0.23
|
||||
pip install -i https://test.pypi.org/simple/ formers==0.0.15.dev376
|
||||
```
|
||||
|
||||
We'll update these instructions when the wheels are published to the official PyPI repository.
|
||||
@@ -97,7 +97,7 @@ Running the pipeline is then identical to the code above as it's the same model
|
||||
>>> image.save("image_of_squirrel_painting.png")
|
||||
```
|
||||
|
||||
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
|
||||
Diffusion systems can be used with multiple different [schedulers](./api/schedulers/overview) each with their
|
||||
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
|
||||
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
|
||||
you could use it as follows:
|
||||
|
||||
@@ -21,8 +21,6 @@ The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/m
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
<!-- TODO: replace with our blog when it's done -->
|
||||
|
||||
Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there.
|
||||
|
||||
</Tip>
|
||||
@@ -38,23 +36,17 @@ pip install git+https://github.com/huggingface/diffusers
|
||||
pip install -U -r diffusers/examples/dreambooth/requirements.txt
|
||||
```
|
||||
|
||||
Then initialize and configure a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
xFormers is not part of the training requirements, but [we recommend you install it if you can](../optimization/xformers). It could make your training faster and less memory intensive.
|
||||
|
||||
After all dependencies have been set up you can configure a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
In this example we'll use model version `v1-4`, so please visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4) and carefully read the license before proceeding.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there.
|
||||
The command below will download and cache the model weights from the Hub because we use the model's Hub id `CompVis/stable-diffusion-v1-4`. You may also clone the repo locally and use the local path in your system where the checkout was saved.
|
||||
|
||||
### Dog toy example
|
||||
|
||||
@@ -111,6 +103,59 @@ accelerate launch train_dreambooth.py \
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Saving checkpoints while training
|
||||
|
||||
It's easy to overfit while training with Dreambooth, so sometimes it's useful to save regular checkpoints during the process. One of the intermediate checkpoints might work better than the final model! To use this feature you need to pass the following argument to the training script:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
|
||||
This will save the full training state in subfolders of your `output_dir`. Subfolder names begin with the prefix `checkpoint-`, and then the number of steps performed so far; for example: `checkpoint-1500` would be a checkpoint saved after 1500 training steps.
|
||||
|
||||
#### Resuming training from a saved checkpoint
|
||||
|
||||
If you want to resume training from any of the saved checkpoints, you can pass the argument `--resume_from_checkpoint` and then indicate the name of the checkpoint you want to use. You can also use the special string `"latest"` to resume from the last checkpoint saved (i.e., the one with the largest number of steps). For example, the following would resume training from the checkpoint saved after 1500 steps:
|
||||
|
||||
```bash
|
||||
--resume_from_checkpoint="checkpoint-1500"
|
||||
```
|
||||
|
||||
This would be a good opportunity to tweak some of your hyperparameters if you wish.
|
||||
|
||||
#### Performing inference using a saved checkpoint
|
||||
|
||||
Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
|
||||
|
||||
You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
# Load the pipeline with the same arguments (model, revision) that were used for training
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
pipeline = DiffusionPipeline.from_pretrained(model_id)
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Use text_encoder if `--train_text_encoder` was used for the initial training
|
||||
unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
|
||||
|
||||
# Restore state from a checkpoint path. You have to use the absolute path here.
|
||||
accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")
|
||||
|
||||
# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
)
|
||||
|
||||
# Perform inference, or save, or push to the hub
|
||||
pipeline.save_pretrained("dreambooth-pipeline")
|
||||
```
|
||||
|
||||
### Training on a 16GB GPU
|
||||
|
||||
With the help of gradient checkpointing and the 8-bit optimizer from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), it's possible to train dreambooth on a 16GB GPU.
|
||||
|
||||
@@ -38,6 +38,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
- [Text Inversion](./text_inversion)
|
||||
- [Dreambooth](./dreambooth)
|
||||
|
||||
If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive.
|
||||
|
||||
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|
||||
|---|---|:---:|:---:|
|
||||
|
||||
@@ -58,7 +58,6 @@ guided_pipeline = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
guided_pipeline.enable_attention_slicing()
|
||||
@@ -113,7 +112,6 @@ import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
custom_pipeline="interpolate_stable_diffusion",
|
||||
@@ -159,7 +157,6 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="stable_diffusion_mega",
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.enable_attention_slicing()
|
||||
@@ -204,7 +201,7 @@ from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"hakurei/waifu-diffusion", custom_pipeline="lpw_stable_diffusion", revision="fp16", torch_dtype=torch.float16
|
||||
"hakurei/waifu-diffusion", custom_pipeline="lpw_stable_diffusion", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -268,7 +265,7 @@ diffuser_pipeline = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="speech_to_image_diffusion",
|
||||
speech_model=model,
|
||||
speech_processor=processor,
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ from diffusers import StableDiffusionImg2ImgPipeline
|
||||
|
||||
# load the pipeline
|
||||
device = "cuda"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
|
||||
).to(device)
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
|
||||
device
|
||||
)
|
||||
|
||||
# let's download an initial image
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
|
||||
@@ -42,7 +42,6 @@ mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
73
docs/source/using-diffusers/reusing_seeds.mdx
Normal file
73
docs/source/using-diffusers/reusing_seeds.mdx
Normal file
@@ -0,0 +1,73 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Re-using seeds for fast prompt engineering
|
||||
|
||||
A common use case when generating images is to generate a batch of images, select one image and improve it with a better, more detailed prompt in a second run.
|
||||
To do this, one needs to make each generated image of the batch deterministic.
|
||||
Images are generated by denoising gaussian random noise which can be instantiated by passing a [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator).
|
||||
|
||||
Now, for batched generation, we need to make sure that every single generated image in the batch is tied exactly to one seed. In 🧨 Diffusers, this can be achieved by not passing one `generator`, but a list
|
||||
of `generators` to the pipeline.
|
||||
|
||||
Let's go through an example using [`runwayml/stable-diffusion-v1-5`](runwayml/stable-diffusion-v1-5).
|
||||
We want to generate several versions of the prompt:
|
||||
|
||||
```py
|
||||
prompt = "Labrador in the style of Vermeer"
|
||||
```
|
||||
|
||||
Let's load the pipeline
|
||||
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
```
|
||||
|
||||
Now, let's define 4 different generators, since we would like to reproduce a certain image. We'll use seeds `0` to `3` to create our generators.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
|
||||
>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
|
||||
```
|
||||
|
||||
Let's generate 4 images:
|
||||
|
||||
```python
|
||||
>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
|
||||
>>> images
|
||||
```
|
||||
|
||||

|
||||
|
||||
Ok, the last images has some double eyes, but the first image looks good!
|
||||
Let's try to make the prompt a bit better **while keeping the first seed**
|
||||
so that the images are similar to the first image.
|
||||
|
||||
```python
|
||||
prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
|
||||
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
|
||||
```
|
||||
|
||||
We create 4 generators with seed `0`, which is the first seed we used before.
|
||||
|
||||
Let's run the pipeline again.
|
||||
|
||||
```python
|
||||
>>> images = pipe(prompt, generator=generator).images
|
||||
>>> images
|
||||
```
|
||||
|
||||

|
||||
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
|
||||
# Schedulers
|
||||
|
||||
Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
|
||||
a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers.mdx).
|
||||
a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers/overview.mdx).
|
||||
|
||||
Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
|
||||
schedulers define the whole denoising process, *i.e.*:
|
||||
|
||||
@@ -52,6 +52,10 @@ For such examples, we are more lenient regarding the philosophy defined above an
|
||||
Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
|
||||
**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
|
||||
|
||||
## Research Projects
|
||||
|
||||
We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
|
||||
|
||||
## Important note
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
@@ -57,7 +57,7 @@ guided_pipeline = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
guided_pipeline.enable_attention_slicing()
|
||||
@@ -208,7 +208,7 @@ import torch
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
'hakurei/waifu-diffusion',
|
||||
custom_pipeline="lpw_stable_diffusion",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe=pipe.to("cuda")
|
||||
@@ -275,7 +275,7 @@ diffuser_pipeline = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="speech_to_image_diffusion",
|
||||
speech_model=model,
|
||||
speech_processor=processor,
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
@@ -333,7 +333,7 @@ import torch
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="wildcard_stable_diffusion",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
|
||||
@@ -567,7 +567,7 @@ diffuser_pipeline = DiffusionPipeline.from_pretrained(
|
||||
detection_pipeline=language_detection_pipeline,
|
||||
translation_model=trans_model,
|
||||
translation_tokenizer=trans_tokenizer,
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
@@ -615,7 +615,7 @@ mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
custom_pipeline="img2img_inpainting",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -68,7 +68,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
Example Usage:
|
||||
pipe = WildcardStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="fp16",
|
||||
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
|
||||
|
||||
@@ -44,20 +44,6 @@ write_basic_config()
|
||||
|
||||
### Dog toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
|
||||
<br>
|
||||
|
||||
Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
|
||||
|
||||
And launch the training using
|
||||
|
||||
@@ -155,7 +155,8 @@ def parse_args(input_args=None):
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -242,6 +242,25 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
||||
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
|
||||
" using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -591,6 +610,7 @@ def main():
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -628,14 +648,39 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1]
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, args.num_epochs):
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
|
||||
@@ -719,6 +764,12 @@ def main():
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
17
examples/research_projects/intel_opts/README.md
Normal file
17
examples/research_projects/intel_opts/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
## Diffusers examples with Intel optimizations
|
||||
|
||||
**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**
|
||||
|
||||
This aims to provide diffusers examples with Intel optimizations such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration on Intel platforms.
|
||||
|
||||
## Accelerating the fine-tuning for textual inversion
|
||||
|
||||
We accelereate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
|
||||
|
||||
## Accelerating the inference for Stable Diffusion using Bfloat16
|
||||
|
||||
We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.
|
||||
|
||||
## Accelerating the inference for Stable Diffusion using INT8
|
||||
|
||||
Coming soon ...
|
||||
49
examples/research_projects/intel_opts/inference_bf16.py
Normal file
49
examples/research_projects/intel_opts/inference_bf16.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
assert len(imgs) == rows * cols
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
return grid
|
||||
|
||||
|
||||
prompt = ["a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings"]
|
||||
batch_size = 8
|
||||
prompt = prompt * batch_size
|
||||
|
||||
device = "cpu"
|
||||
model_id = "path-to-your-trained-model"
|
||||
model = StableDiffusionPipeline.from_pretrained(model_id)
|
||||
model = model.to(device)
|
||||
|
||||
# to channels last
|
||||
model.unet = model.unet.to(memory_format=torch.channels_last)
|
||||
model.vae = model.vae.to(memory_format=torch.channels_last)
|
||||
model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last)
|
||||
model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last)
|
||||
|
||||
# optimize with ipex
|
||||
model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True)
|
||||
model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True)
|
||||
model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
|
||||
model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
|
||||
|
||||
# compute
|
||||
seed = 666
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
||||
images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images
|
||||
|
||||
# save image
|
||||
grid = image_grid(images, rows=2, cols=4)
|
||||
grid.save(model_id + ".png")
|
||||
@@ -0,0 +1,68 @@
|
||||
## Textual Inversion fine-tuning example
|
||||
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Training with Intel Extension for PyTorch
|
||||
|
||||
Intel Extension for PyTorch provides the optimizations for faster training and inference on CPUs. You can leverage the training example "textual_inversion.py". Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.
|
||||
|
||||
The example supports both single node and multi-node distributed training:
|
||||
|
||||
### Single node training
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export DATA_DIR="path-to-dir-containing-dicoo-images"
|
||||
|
||||
python textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--seed=7 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=2.5e-03 --scale_lr \
|
||||
--output_dir="textual_inversion_dicoo"
|
||||
```
|
||||
|
||||
Note: Bfloat16 is available on Intel Xeon Scalable Processors Cooper Lake or Sapphire Rapids. You may not get performance speedup without Bfloat16 support.
|
||||
|
||||
### Multi-node distributed training
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies successfully:
|
||||
|
||||
```bash
|
||||
python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
|
||||
```
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export DATA_DIR="path-to-dir-containing-dicoo-images"
|
||||
|
||||
oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
|
||||
source $oneccl_bindings_for_pytorch_path/env/setvars.sh
|
||||
|
||||
python -m intel_extension_for_pytorch.cpu.launch --distributed \
|
||||
--hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--seed=7 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--max_train_steps=750 \
|
||||
--learning_rate=2.5e-03 --scale_lr \
|
||||
--output_dir="textual_inversion_dicoo"
|
||||
```
|
||||
The above is a simple distributed training usage on 2 nodes with 2 processes on each node. Add the right hostname or ip address in the "hostfile" and make sure these 2 nodes are reachable from each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).
|
||||
|
||||
|
||||
### Reference
|
||||
|
||||
We publish a [Medium blog](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out now, if you have interests.
|
||||
@@ -0,0 +1,7 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.21.0
|
||||
ftfy
|
||||
tensorboard
|
||||
modelcards
|
||||
intel_extension_for_pytorch>=1.13
|
||||
@@ -0,0 +1,645 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.utils import check_min_version
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only_save_embeds",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Save only the embeddings for the new concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A token to use as a placeholder for the concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
|
||||
)
|
||||
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
||||
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.train_data_dir is None:
|
||||
raise ValueError("You must specify a train data directory.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL_INTERPOLATION["linear"],
|
||||
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.train_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=args.repeats,
|
||||
learnable_property=args.learnable_property,
|
||||
center_crop=args.center_crop,
|
||||
set="train",
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# Keep vae and unet in eval model as we don't train these
|
||||
vae.eval()
|
||||
unet.eval()
|
||||
|
||||
unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)
|
||||
vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
text_encoder.train()
|
||||
text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
if accelerator.num_processes > 1:
|
||||
grads = text_encoder.module.get_input_embeddings().weight.grad
|
||||
else:
|
||||
grads = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -216,6 +216,24 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -528,6 +546,7 @@ def main():
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
@@ -567,16 +586,40 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1]
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
|
||||
@@ -629,6 +672,12 @@ def main():
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
train_loss = 0.0
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
|
||||
@@ -205,6 +205,24 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -512,10 +530,17 @@ def main():
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Keep vae and unet in eval model as we don't train these
|
||||
vae.eval()
|
||||
@@ -543,24 +568,49 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1]
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -572,7 +622,7 @@ def main():
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
@@ -585,7 +635,7 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
@@ -605,6 +655,12 @@ def main():
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
@@ -173,6 +173,16 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
choices=["tensorboard", "wandb"],
|
||||
help=(
|
||||
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
|
||||
" for experiment tracking and logging of model metrics and model checkpoints"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
@@ -194,7 +204,6 @@ def parse_args():
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
@@ -202,9 +211,26 @@ def parse_args():
|
||||
choices=["epsilon", "sample"],
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
|
||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -232,7 +258,7 @@ def main(args):
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
log_with=args.logger,
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
@@ -319,6 +345,7 @@ def main(args):
|
||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
@@ -351,11 +378,36 @@ def main(args):
|
||||
accelerator.init_trackers(run)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(args.num_epochs):
|
||||
first_epoch = 0
|
||||
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1]
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
||||
|
||||
for epoch in range(first_epoch, args.num_epochs):
|
||||
model.train()
|
||||
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description(f"Epoch {epoch}")
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
clean_images = batch["input"]
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
@@ -402,6 +454,12 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
|
||||
if args.use_ema:
|
||||
logs["ema_decay"] = ema_model.decay
|
||||
@@ -429,9 +487,11 @@ def main(args):
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
accelerator.trackers[0].writer.add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
accelerator.get_tracker("tensorboard").add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
|
||||
1159
scripts/convert_kakao_brain_unclip_to_diffusers.py
Normal file
1159
scripts/convert_kakao_brain_unclip_to_diffusers.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -803,7 +803,7 @@ if __name__ == "__main__":
|
||||
"--scheduler_type",
|
||||
default="pndm",
|
||||
type=str,
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pipeline_type",
|
||||
|
||||
@@ -692,7 +692,7 @@ if __name__ == "__main__":
|
||||
"--scheduler_type",
|
||||
default="pndm",
|
||||
type=str,
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extract_ema",
|
||||
|
||||
2
setup.py
2
setup.py
@@ -218,7 +218,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.11.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.11.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.11.0.dev0"
|
||||
__version__ = "0.11.1"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
@@ -25,7 +25,15 @@ except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .models import (
|
||||
AutoencoderKL,
|
||||
PriorTransformer,
|
||||
Transformer2DModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
VQModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -63,6 +71,7 @@ else:
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
@@ -96,6 +105,7 @@ else:
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineSafe,
|
||||
StableDiffusionUpscalePipeline,
|
||||
UnCLIPPipeline,
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
from .attention import Transformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
@@ -66,8 +66,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
@@ -181,7 +181,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
@@ -260,6 +260,8 @@ class AttentionBlock(nn.Module):
|
||||
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
"""
|
||||
|
||||
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
@@ -369,6 +371,7 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
@@ -385,7 +388,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
@@ -432,7 +435,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
) # is self-attn if context is none
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
@@ -470,25 +473,33 @@ class BasicTransformerBlock(nn.Module):
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
if self.attn2 is not None:
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
|
||||
if self.only_cross_attention:
|
||||
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
|
||||
hidden_states = (
|
||||
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
||||
hidden_states = (
|
||||
self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
@@ -503,7 +514,7 @@ class CrossAttention(nn.Module):
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
@@ -520,13 +531,18 @@ class CrossAttention(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.heads = heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
@@ -534,11 +550,21 @@ class CrossAttention(nn.Module):
|
||||
self.sliceable_head_dim = heads
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
@@ -563,40 +589,64 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self._slice_size = slice_size
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states
|
||||
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
context = context if context is not None else hidden_states
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
|
||||
dim = query.shape[-1]
|
||||
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
if self.added_kv_proj_dim is not None:
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape[-1] != query.shape[1]:
|
||||
target_length = query.shape[1]
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
||||
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
else:
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value)
|
||||
hidden_states = self._attention(query, key, value, attention_mask)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value):
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
@@ -608,6 +658,13 @@ class CrossAttention(nn.Module):
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
|
||||
# cast back to the original dtype
|
||||
@@ -620,7 +677,7 @@ class CrossAttention(nn.Module):
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
||||
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
||||
@@ -644,6 +701,13 @@ class CrossAttention(nn.Module):
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
||||
|
||||
if self.upcast_softmax:
|
||||
attn_slice = attn_slice.float()
|
||||
|
||||
attn_slice = attn_slice.softmax(dim=-1)
|
||||
|
||||
# cast back to the original dtype
|
||||
@@ -656,11 +720,12 @@ class CrossAttention(nn.Module):
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _memory_efficient_attention_xformers(self, query, key, value):
|
||||
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
||||
# TODO attention_mask
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -802,7 +867,7 @@ class DualTransformer2DModel(nn.Module):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
@@ -867,17 +932,21 @@ class DualTransformer2DModel(nn.Module):
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
|
||||
def forward(
|
||||
self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in CrossAttention
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -890,13 +959,17 @@ class DualTransformer2DModel(nn.Module):
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
|
||||
0
|
||||
]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
|
||||
194
src/diffusers/models/prior_transformer.py
Normal file
194
src/diffusers/models/prior_transformer.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
|
||||
transformer predicts the image embeddings through a denoising diffusion process.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2204.06125
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
|
||||
image embeddings and text embeddings are both the same dimension.
|
||||
num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
|
||||
length of the prompt after it has been tokenized.
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
x_t, the currently predicted image embeddings.
|
||||
timestep (`torch.long`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
encoder_hidden_states,
|
||||
proj_embeddings[:, None, :],
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states[:, None, :],
|
||||
prd_embedding,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states[:, -1]
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
|
||||
@@ -405,7 +405,14 @@ class ResnetBlock2D(nn.Module):
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
if self.time_embedding_norm == "default":
|
||||
time_emb_proj_out_channels = out_channels
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
time_emb_proj_out_channels = out_channels * 2
|
||||
else:
|
||||
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
||||
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
|
||||
@@ -465,9 +472,16 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
|
||||
types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
||||
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -88,6 +92,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: int = 8,
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -130,6 +136,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -140,9 +147,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
@@ -167,6 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
|
||||
from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ def get_down_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
@@ -49,6 +50,19 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "ResnetDownsampleBlock2D":
|
||||
return ResnetDownsampleBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnDownBlock2D":
|
||||
return AttnDownBlock2D(
|
||||
@@ -62,6 +76,7 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -82,6 +97,23 @@ def get_down_block(
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "SimpleCrossAttnDownBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
|
||||
return SimpleCrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
@@ -93,6 +125,7 @@ def get_down_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnSkipDownBlock2D":
|
||||
return AttnSkipDownBlock2D(
|
||||
@@ -105,6 +138,7 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "DownEncoderBlock2D":
|
||||
return DownEncoderBlock2D(
|
||||
@@ -116,6 +150,7 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnDownEncoderBlock2D":
|
||||
return AttnDownEncoderBlock2D(
|
||||
@@ -128,6 +163,7 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
@@ -149,6 +185,7 @@ def get_up_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
@@ -162,6 +199,20 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "ResnetUpsampleBlock2D":
|
||||
return ResnetUpsampleBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -182,6 +233,24 @@ def get_up_block(
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "SimpleCrossAttnUpBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
|
||||
return SimpleCrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
return AttnUpBlock2D(
|
||||
@@ -195,6 +264,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "SkipUpBlock2D":
|
||||
return SkipUpBlock2D(
|
||||
@@ -206,6 +276,7 @@ def get_up_block(
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnSkipUpBlock2D":
|
||||
return AttnSkipUpBlock2D(
|
||||
@@ -218,6 +289,7 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "UpDecoderBlock2D":
|
||||
return UpDecoderBlock2D(
|
||||
@@ -228,6 +300,7 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnUpDecoderBlock2D":
|
||||
return AttnUpDecoderBlock2D(
|
||||
@@ -239,6 +312,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
@@ -255,14 +329,13 @@ class UNetMidBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
add_attention: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_type = attention_type
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
@@ -282,15 +355,19 @@ class UNetMidBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
@@ -309,13 +386,11 @@ class UNetMidBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_states=None):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.attention_type == "default":
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
else:
|
||||
hidden_states = attn(hidden_states, encoder_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -334,7 +409,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
@@ -344,7 +418,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
@@ -408,10 +481,104 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
self.num_heads = in_channels // self.attn_num_head_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=in_channels,
|
||||
cross_attention_dim=in_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -431,7 +598,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -440,8 +606,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
@@ -514,7 +678,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -528,7 +691,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -588,7 +750,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -605,7 +768,9 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -847,7 +1012,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -856,8 +1020,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList([])
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
self.resnets.append(
|
||||
@@ -1006,6 +1168,205 @@ class SkipDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states, skip_sample
|
||||
|
||||
|
||||
class ResnetDownsampleBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
self.num_heads = out_channels // self.attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class AttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1020,7 +1381,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attention_type="default",
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
@@ -1029,8 +1389,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -1100,7 +1458,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
@@ -1113,7 +1470,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -1176,7 +1532,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -1196,7 +1554,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1418,7 +1778,6 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
upsample_padding=1,
|
||||
add_upsample=True,
|
||||
@@ -1427,8 +1786,6 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList([])
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -1608,3 +1965,213 @@ class SkipUpBlock2D(nn.Module):
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
|
||||
class ResnetUpsampleBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
self.num_heads = out_channels // self.attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -27,6 +27,7 @@ from .unet_2d_blocks import (
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
UNetMidBlock2DSimpleCrossAttn,
|
||||
UpBlock2D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
@@ -66,6 +67,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
@@ -78,6 +81,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -97,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: str = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
@@ -110,8 +118,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -128,8 +138,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -165,24 +181,40 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
@@ -223,6 +255,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -307,6 +340,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
@@ -336,6 +370,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
@@ -365,9 +404,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
@@ -382,6 +425,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
@@ -389,7 +433,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
@@ -410,6 +456,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
|
||||
@@ -113,7 +113,6 @@ from diffusers import StableDiffusionImg2ImgPipeline
|
||||
device = "cuda"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to(device)
|
||||
|
||||
@@ -161,7 +160,6 @@ mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -53,6 +53,7 @@ else:
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .unclip import UnCLIPPipeline
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
|
||||
@@ -379,12 +379,24 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -405,7 +417,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -440,8 +452,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -396,8 +396,22 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -410,16 +424,24 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -438,7 +460,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -478,8 +500,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 100,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[AudioPipelineOutput, Tuple]:
|
||||
@@ -57,8 +57,8 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
|
||||
the expense of slower inference.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
|
||||
The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
|
||||
`sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`.
|
||||
@@ -94,9 +94,23 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
sample_size = int(sample_size)
|
||||
|
||||
dtype = next(iter(self.unet.parameters())).dtype
|
||||
audio = torch.randn(
|
||||
(batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
|
||||
)
|
||||
shape = (batch_size, self.unet.in_channels, sample_size)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
audio = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
audio = torch.cat(audio, dim=0).to(self.device)
|
||||
else:
|
||||
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -40,7 +40,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
use_clipped_model_output: Optional[bool] = None,
|
||||
@@ -52,8 +52,8 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
@@ -74,7 +74,12 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
if (
|
||||
generator is not None
|
||||
and isinstance(generator, torch.Generator)
|
||||
and generator.device.type != self.device.type
|
||||
and self.device.type != "mps"
|
||||
):
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
@@ -93,12 +98,23 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + image_shape[1:]
|
||||
image = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
image = torch.cat(image, dim=0).to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
|
||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,7 +42,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -53,8 +53,8 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 1000):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
@@ -71,7 +71,8 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 1.0,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -94,8 +95,12 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at
|
||||
the, usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -123,17 +128,41 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
|
||||
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, self.unet.in_channels, height // 8, width // 8),
|
||||
generator=generator,
|
||||
)
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
|
||||
if isinstance(generator, list):
|
||||
latents_shape = (1,) + latents_shape[1:]
|
||||
latents = [
|
||||
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0)
|
||||
else:
|
||||
latents = torch.randn(
|
||||
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
|
||||
)
|
||||
latents = latents.to(self.device)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -70,7 +70,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
||||
batch_size: Optional[int] = 1,
|
||||
num_inference_steps: Optional[int] = 100,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -89,8 +89,8 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
output_type: Optional[str] = "pil",
|
||||
@@ -55,8 +55,8 @@ class LDMPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
@@ -291,12 +291,24 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -321,7 +333,14 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
@@ -390,7 +409,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -433,8 +452,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,7 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
|
||||
@@ -13,33 +13,61 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import RePaintScheduler
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging
|
||||
|
||||
|
||||
def _preprocess_image(image: PIL.Image.Image):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
|
||||
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
|
||||
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
def _preprocess_mask(mask: PIL.Image.Image):
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
|
||||
if isinstance(mask, torch.Tensor):
|
||||
return mask
|
||||
elif isinstance(mask, PIL.Image.Image):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask[0], PIL.Image.Image):
|
||||
w, h = mask[0].size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
|
||||
mask = np.concatenate(mask, axis=0)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
elif isinstance(mask[0], torch.Tensor):
|
||||
mask = torch.cat(mask, dim=0)
|
||||
return mask
|
||||
|
||||
|
||||
@@ -54,19 +82,20 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
original_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
image: Union[torch.Tensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.Tensor, PIL.Image.Image],
|
||||
num_inference_steps: int = 250,
|
||||
eta: float = 0.0,
|
||||
jump_length: int = 10,
|
||||
jump_n_sample: int = 10,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The original image to inpaint on.
|
||||
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
The mask_image where 0.0 values define which part of the original image to inpaint (change).
|
||||
@@ -83,8 +112,8 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
The number of times we will make forward time jump for a given chosen time sample. Take a look at
|
||||
Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -97,27 +126,44 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if not isinstance(original_image, torch.FloatTensor):
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(self.device)
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(self.device)
|
||||
message = "Please use `image` instead of `original_image`."
|
||||
original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
|
||||
original_image = original_image or image
|
||||
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
mask_image = _preprocess_mask(mask_image)
|
||||
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
|
||||
batch_size = original_image.shape[0]
|
||||
|
||||
# sample gaussian noise to begin the loop
|
||||
image = torch.randn(
|
||||
original_image.shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image = image.to(self.device)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
rand_device = "cpu" if self.device.type == "mps" else self.device
|
||||
image_shape = original_image.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + image_shape[1:]
|
||||
image = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
image = torch.cat(image, dim=0).to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
|
||||
self.scheduler.eta = eta
|
||||
|
||||
t_last = self.scheduler.timesteps[0] + 1
|
||||
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
|
||||
generator = generator[0] if isinstance(generator, list) else generator
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
if t < t_last:
|
||||
# predict the noise residual
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -41,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 2000,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -51,8 +51,8 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -435,8 +435,22 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -458,7 +472,16 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
|
||||
# add noise to latents using the timestep
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
clean_latents = init_latents
|
||||
@@ -479,7 +502,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
source_guidance_scale: Optional[float] = 1,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -519,8 +542,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -309,13 +309,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
generation. Can be used to tweak the same generation with different prompts. tensor will ge generated
|
||||
by sampling using the supplied random `generator`.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
|
||||
@@ -90,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker"]
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -338,7 +338,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
image = preprocess(image)
|
||||
image = preprocess(image).cpu().numpy()
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
|
||||
@@ -378,12 +378,24 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -404,7 +416,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -439,8 +451,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -345,8 +345,22 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -359,16 +373,24 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -429,7 +451,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -468,8 +490,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -258,12 +258,24 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -283,7 +295,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -318,8 +330,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -401,8 +401,22 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
@@ -415,16 +429,24 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
shape = init_latents.shape
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
noise = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
|
||||
]
|
||||
noise = torch.cat(noise, dim=0).to(device)
|
||||
else:
|
||||
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
@@ -443,7 +465,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -483,8 +505,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
||||
@@ -463,12 +463,24 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -492,7 +504,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
@@ -535,7 +554,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -578,8 +597,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -430,8 +430,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
add_predicted_noise: Optional[bool] = False,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
@@ -471,12 +472,15 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
add_predicted_noise (`bool`, *optional*, defaults to True):
|
||||
Use predicted noise instead of random noise when constructing noisy versions of the original image in
|
||||
the reverse diffusion process
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -561,7 +565,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
if add_predicted_noise:
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_orig, noise_pred_uncond, torch.tensor([t])
|
||||
)
|
||||
else:
|
||||
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
||||
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -367,8 +367,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -338,7 +338,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -371,8 +371,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -422,12 +422,24 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -490,7 +502,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -530,8 +542,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -49,7 +49,7 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -59,8 +59,8 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
|
||||
16
src/diffusers/pipelines/unclip/__init__.py
Normal file
16
src/diffusers/pipelines/unclip/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import UnCLIPPipeline
|
||||
else:
|
||||
from .pipeline_unclip import UnCLIPPipeline
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
489
src/diffusers/pipelines/unclip/pipeline_unclip.py
Normal file
489
src/diffusers/pipelines/unclip/pipeline_unclip.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from diffusers.schedulers import UnCLIPScheduler
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using unCLIP
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
prior ([`PriorTransformer`]):
|
||||
The canonincal unCLIP prior to approximate the image embedding from the text embedding.
|
||||
decoder ([`UNet2DConditionModel`]):
|
||||
The decoder to invert the image embedding into an image.
|
||||
super_res_first ([`UNet2DModel`]):
|
||||
Super resolution unet. Used in all but the last step of the super resolution diffusion process.
|
||||
super_res_last ([`UNet2DModel`]):
|
||||
Super resolution unet. Used in the last step of the super resolution diffusion process.
|
||||
prior_scheduler ([`UnCLIPScheduler`]):
|
||||
Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
|
||||
decoder_scheduler ([`UnCLIPScheduler`]):
|
||||
Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
|
||||
super_res_scheduler ([`UnCLIPScheduler`]):
|
||||
Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
|
||||
|
||||
"""
|
||||
|
||||
prior: PriorTransformer
|
||||
decoder: UNet2DConditionModel
|
||||
text_proj: UnCLIPTextProjModel
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
tokenizer: CLIPTokenizer
|
||||
super_res_first: UNet2DModel
|
||||
super_res_last: UNet2DModel
|
||||
|
||||
prior_scheduler: UnCLIPScheduler
|
||||
decoder_scheduler: UnCLIPScheduler
|
||||
super_res_scheduler: UnCLIPScheduler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
decoder: UNet2DConditionModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_proj: UnCLIPTextProjModel,
|
||||
super_res_first: UNet2DModel,
|
||||
super_res_last: UNet2DModel,
|
||||
prior_scheduler: UnCLIPScheduler,
|
||||
decoder_scheduler: UnCLIPScheduler,
|
||||
super_res_scheduler: UnCLIPScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
decoder=decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_proj=text_proj,
|
||||
super_res_first=super_res_first,
|
||||
super_res_last=super_res_last,
|
||||
prior_scheduler=prior_scheduler,
|
||||
decoder_scheduler=decoder_scheduler,
|
||||
super_res_scheduler=super_res_scheduler,
|
||||
)
|
||||
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
text_embeddings = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""] * batch_size
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
|
||||
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
|
||||
|
||||
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
|
||||
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# done duplicates
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
|
||||
|
||||
text_mask = torch.cat([uncond_text_mask, text_mask])
|
||||
|
||||
return text_embeddings, text_encoder_hidden_states, text_mask
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
# TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
|
||||
models = [
|
||||
self.decoder,
|
||||
self.text_proj,
|
||||
self.text_encoder,
|
||||
self.super_res_first,
|
||||
self.super_res_last,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.decoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
prior_num_inference_steps: int = 25,
|
||||
decoder_num_inference_steps: int = 25,
|
||||
super_res_num_inference_steps: int = 7,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
prior_latents: Optional[torch.FloatTensor] = None,
|
||||
decoder_latents: Optional[torch.FloatTensor] = None,
|
||||
super_res_latents: Optional[torch.FloatTensor] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
decoder_guidance_scale: float = 8.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
prior_num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
|
||||
image at the expense of slower inference.
|
||||
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
|
||||
image at the expense of slower inference.
|
||||
super_res_num_inference_steps (`int`, *optional*, defaults to 7):
|
||||
The number of denoising steps for super resolution. More denoising steps usually lead to a higher
|
||||
quality image at the expense of slower inference.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for the prior.
|
||||
decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for the decoder.
|
||||
super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for the decoder.
|
||||
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
device = self._execution_device
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
|
||||
|
||||
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
|
||||
# prior
|
||||
|
||||
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
|
||||
prior_timesteps_tensor = self.prior_scheduler.timesteps
|
||||
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
prior_latents = self.prepare_latents(
|
||||
(batch_size, embedding_dim),
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
prior_latents,
|
||||
self.prior_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
|
||||
|
||||
predicted_image_embedding = self.prior(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=text_embeddings,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
attention_mask=text_mask,
|
||||
).predicted_image_embedding
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
|
||||
predicted_image_embedding_text - predicted_image_embedding_uncond
|
||||
)
|
||||
|
||||
if i + 1 == prior_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = prior_timesteps_tensor[i + 1]
|
||||
|
||||
prior_latents = self.prior_scheduler.step(
|
||||
predicted_image_embedding,
|
||||
timestep=t,
|
||||
sample=prior_latents,
|
||||
generator=generator,
|
||||
prev_timestep=prev_timestep,
|
||||
).prev_sample
|
||||
|
||||
prior_latents = self.prior.post_process_latents(prior_latents)
|
||||
|
||||
image_embeddings = prior_latents
|
||||
|
||||
# done prior
|
||||
|
||||
# decoder
|
||||
|
||||
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
|
||||
image_embeddings=image_embeddings,
|
||||
text_embeddings=text_embeddings,
|
||||
text_encoder_hidden_states=text_encoder_hidden_states,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
|
||||
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
device,
|
||||
generator,
|
||||
decoder_latents,
|
||||
self.decoder_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
|
||||
|
||||
noise_pred = self.decoder(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
class_labels=additive_clip_time_embeddings,
|
||||
attention_mask=decoder_text_mask,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if i + 1 == decoder_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = decoder_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
decoder_latents = self.decoder_scheduler.step(
|
||||
noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
|
||||
).prev_sample
|
||||
|
||||
decoder_latents = decoder_latents.clamp(-1, 1)
|
||||
|
||||
image_small = decoder_latents
|
||||
|
||||
# done decoder
|
||||
|
||||
# super res
|
||||
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
device,
|
||||
generator,
|
||||
super_res_latents,
|
||||
self.super_res_scheduler,
|
||||
)
|
||||
|
||||
interpolate_antialias = {}
|
||||
if "antialias" in inspect.signature(F.interpolate).parameters:
|
||||
interpolate_antialias["antialias"] = True
|
||||
|
||||
image_upscaled = F.interpolate(
|
||||
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
|
||||
# no classifier free guidance
|
||||
|
||||
if i == super_res_timesteps_tensor.shape[0] - 1:
|
||||
unet = self.super_res_last
|
||||
else:
|
||||
unet = self.super_res_first
|
||||
|
||||
latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
|
||||
|
||||
noise_pred = unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
).sample
|
||||
|
||||
if i + 1 == super_res_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = super_res_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
super_res_latents = self.super_res_scheduler.step(
|
||||
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
|
||||
).prev_sample
|
||||
|
||||
image = super_res_latents
|
||||
|
||||
# done super res
|
||||
|
||||
# post processing
|
||||
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
87
src/diffusers/pipelines/unclip/text_proj.py
Normal file
87
src/diffusers/pipelines/unclip/text_proj.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
|
||||
|
||||
class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the
|
||||
decoder.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
clip_extra_context_tokens: int = 4,
|
||||
clip_embeddings_dim: int = 768,
|
||||
time_embed_dim: int,
|
||||
cross_attention_dim,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim))
|
||||
|
||||
# parameters for additional clip time embeddings
|
||||
self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
|
||||
self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim)
|
||||
|
||||
# parameters for encoder hidden states
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.clip_extra_context_tokens_proj = nn.Linear(
|
||||
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
|
||||
)
|
||||
self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
|
||||
self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
|
||||
if do_classifier_free_guidance:
|
||||
# Add the classifier free guidance embeddings to the image embeddings
|
||||
image_embeddings_batch_size = image_embeddings.shape[0]
|
||||
classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0)
|
||||
classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand(
|
||||
image_embeddings_batch_size, -1
|
||||
)
|
||||
image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)
|
||||
|
||||
# The image embeddings batch size and the text embeddings batch size are equal
|
||||
assert image_embeddings.shape[0] == text_embeddings.shape[0]
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
|
||||
# "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
|
||||
# adding CLIP embeddings to the existing timestep embedding, ...
|
||||
time_projected_text_embeddings = self.embedding_proj(text_embeddings)
|
||||
time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
|
||||
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings
|
||||
|
||||
# ... and by projecting CLIP embeddings into four
|
||||
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
|
||||
clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
|
||||
clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
|
||||
|
||||
text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
|
||||
text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
|
||||
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
|
||||
|
||||
return text_encoder_hidden_states, additive_clip_time_embeddings
|
||||
@@ -6,8 +6,9 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...modeling_utils import ModelMixin
|
||||
from ...models.attention import DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.embeddings import TimestepEmbedding, Timesteps
|
||||
from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import logging
|
||||
|
||||
@@ -32,6 +33,7 @@ def get_down_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlockFlat":
|
||||
@@ -45,6 +47,7 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlockFlat":
|
||||
if cross_attention_dim is None:
|
||||
@@ -64,6 +67,7 @@ def get_down_block(
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} is not supported.")
|
||||
|
||||
@@ -85,6 +89,7 @@ def get_up_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlockFlat":
|
||||
@@ -98,6 +103,7 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlockFlat":
|
||||
if cross_attention_dim is None:
|
||||
@@ -117,6 +123,7 @@ def get_up_block(
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} is not supported.")
|
||||
|
||||
@@ -141,6 +148,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
@@ -153,6 +162,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -172,6 +185,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnDownBlockFlat",
|
||||
"DownBlockFlat",
|
||||
),
|
||||
mid_block_type: str = "UNetMidBlockFlatCrossAttn",
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
@@ -190,8 +204,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -208,8 +224,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -245,24 +267,40 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockFlatCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
if mid_block_type == "UNetMidBlockFlatCrossAttn":
|
||||
self.mid_block = UNetMidBlockFlatCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
|
||||
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
@@ -303,6 +341,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -387,6 +426,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
@@ -416,6 +456,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
@@ -445,9 +490,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
@@ -462,6 +511,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
@@ -469,7 +519,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
@@ -490,6 +542,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
@@ -715,7 +768,6 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -729,7 +781,6 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -789,7 +840,8 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -806,7 +858,9 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -915,7 +969,6 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
@@ -928,7 +981,6 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -991,7 +1043,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -1011,7 +1065,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1038,7 +1094,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
@@ -1048,7 +1103,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
@@ -1112,10 +1166,105 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
self.num_heads = in_channels // self.attn_num_head_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=in_channels,
|
||||
cross_attention_dim=in_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -91,7 +91,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -126,8 +126,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@@ -207,7 +207,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -242,8 +242,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
@@ -320,7 +320,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -355,8 +355,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
@@ -376,12 +376,24 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
|
||||
if isinstance(generator, list):
|
||||
shape = (1,) + shape[1:]
|
||||
latents = [
|
||||
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
latents = torch.cat(latents, dim=0).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
@@ -416,7 +428,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
@@ -452,8 +464,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user