Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-11 15:04:45 +08:00

Compare commits: 67 commits, 0.1.0 ... update-che
Commits in this range (SHA1):
fd5c52e16b, df90f0ce98, ed22b4fd07, f9522d825c, 80e0c8ba9e, 3cd20d59d7,
e36a36788e, 4b02f53e62, 27d11a0094, 554e67cb06, 45cb500667, 8c78e73fef,
c1b378db69, b50a9ae383, ea2e177c1d, 513f1fbfb0, d7b692083c, 9070c394aa,
194ed794d8, 051b34635f, 5f25818a0f, c25d8c905c, 5782e0393d, 92b6dbba1a,
c72e343085, 3228eb1609, c1488ff348, b344c953a8, dd10da76a7, 543ee1e092,
75b6c16567, c4ae7c2421, a2090375ca, c4a3b09a36, 616c3a42cb, d23cf98769,
eeb9264acd, b6447fa87e, b6cadcef98, 3100bc9670, e05f03ae41, 6c15636b0b,
89f2011ced, 0f8547c2af, 343180c2cf, 27782bc18e, cde0ed162a, 570d3f1eb9,
85244d4a59, 1a84bd2a0f, 3247eadde4, a487b5095a, 04fa7baea8, 9a04a8a6a8,
a05a5fb9ba, 71faf347fd, 2f1f7b01d6, 5311f564ed, 3b7f514a1c, 7c0a861894,
a73ae3e5b0, 06505ba4b4, 13457002c0, 302b86bd0b, d87d5edf66, e795a4c6f8,
4293b9f54f
.github/ISSUE_TEMPLATE/bug-report.yml (vendored, new file, 37 lines)

@@ -0,0 +1,37 @@
name: "\U0001F41B Bug Report"
description: Report a bug on diffusers
labels: [ "bug" ]
body:
  - type: markdown
    attributes:
      value: |
        Thanks for taking the time to fill out this bug report!
  - type: textarea
    id: bug-description
    attributes:
      label: Describe the bug
      description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!
      placeholder: Bug description
    validations:
      required: true
  - type: textarea
    id: reproduction
    attributes:
      label: Reproduction
      description: Please provide minimal reproducible code that we can copy/paste to reproduce the issue.
      placeholder: Reproduction
  - type: textarea
    id: logs
    attributes:
      label: Logs
      description: "Please include the Python logs if you can."
      render: shell
  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: Please share your system info with us.
      render: shell
      placeholder: diffusers version, Python version, etc.
    validations:
      required: true
.github/ISSUE_TEMPLATE/config.yml (vendored, new file, 7 lines)

@@ -0,0 +1,7 @@
contact_links:
  - name: Forum
    url: https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63
    about: General usage questions and community discussions
  - name: Blank issue
    url: https://github.com/huggingface/diffusers/issues/new
    about: Please note that the Forum is in most places the right place for discussions
.github/ISSUE_TEMPLATE/feature_request.md (vendored, new file, 20 lines)

@@ -0,0 +1,20 @@
---
name: "\U0001F680 Feature request"
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
MANIFEST.in

@@ -1 +1 @@
-include diffusers/utils/model_card_template.md
+include src/diffusers/utils/model_card_template.md
Makefile (5 lines removed)

@@ -79,11 +79,6 @@ test:
test-examples:
	python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/

-# Run tests for SageMaker DLC release
-
-test-sagemaker: # install sagemaker dependencies in advance with pip install .[sagemaker]
-	TEST_SAGEMAKER=True python -m pytest -n auto -s -v ./tests/sagemaker
-

# Release stuff
README.md (198 lines changed)

@@ -22,9 +22,99 @@ More precisely, 🤗 Diffusers offers:

- State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)).
- Various noise schedulers that can be used interchangeably for the preferred speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)).
-- Multiple types of models, such as UNet, that can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)).
+- Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)).
- Training examples to show how to train the most popular diffusion models (see [examples](https://github.com/huggingface/diffusers/tree/main/examples)).

## Quickstart

In order to get started, we recommend taking a look at two notebooks:

- The [Getting started with Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) notebook, which showcases an end-to-end example of usage for diffusion models, schedulers and pipelines.
  Take a look at this notebook to learn how to use the pipeline abstraction, which takes care of everything (model, scheduler, noise handling) for you, and also to understand each independent building block in the library.
- The [Training a diffusers model](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook summarizes diffuser model training methods. This notebook takes a step-by-step approach to training your
  diffuser model on an image dataset, with explanatory graphics.
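The pipeline abstraction mentioned above collapses the whole model/scheduler/denoising loop into a couple of lines. A minimal sketch, assuming the 0.2.x-era API used throughout this page (pipelines return a dict whose `"sample"` entry holds PIL images; the model id is the one from the Examples section below):

```python
from diffusers import DiffusionPipeline

# downloads the model weights and scheduler config from the Hub
pipe = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256")

# sample random noise and denoise it into an image
image = pipe()["sample"][0]
image.save("sample.png")
```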
## **New 🎨🎨🎨** Stable Diffusion is now fully compatible with `diffusers`!

Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
See the [model card](https://huggingface.co/CompVis/stable-diffusion) for more information.

**The Stable Diffusion weights are currently only available to universities, academics, research institutions and independent researchers. Please request access by filling out <a href="https://stability.ai/academia-access-form" target="_blank">this</a> form.**

```py
# make sure you're logged in with `huggingface-cli login`
from torch import autocast
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3-diffusers",
    scheduler=lms,
    use_auth_token=True
)

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt)["sample"][0]

image.save("astronaut_rides_horse.png")
```
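The call above uses the pipeline's default sampling settings. In this era of the API the call also accepts knobs such as `num_inference_steps` and `guidance_scale`; a minimal sketch continuing the snippet above (the values are illustrative, not prescriptive):

```python
with autocast("cuda"):
    # more steps = slower sampling but usually higher quality;
    # higher guidance_scale = closer prompt adherence, less diversity
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
```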
For more details, check out [the Stable Diffusion notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
and have a look into the [release notes](https://github.com/huggingface/diffusers/releases/tag/v0.2.0).

## Examples

If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
from diffusers import DiffusionPipeline

model_id = "CompVis/ldm-text2im-large-256"

# load model and scheduler
ldm = DiffusionPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]

# save images
for idx, image in enumerate(images):
    image.save(f"squirrel-{idx}.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline

model_id = "google/ddpm-celebahq-256"

# load model and scheduler
ddpm = DDPMPipeline.from_pretrained(model_id)  # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference

# run pipeline in inference (sample random noise and denoise)
image = ddpm()["sample"]

# save image
image[0].save("ddpm_generated_image.png")
```
- [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256)
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)

If you just want to play around with some web demos, you can try out the following 🚀 Spaces:

| Model | Hugging Face Spaces |
|---|---|
| Text-to-Image Latent Diffusion | [](https://huggingface.co/spaces/CompVis/text2img-latent-diffusion) |
| Faces generator | [](https://huggingface.co/spaces/CompVis/celeba-latent-diffusion) |
| DDPM with different schedulers | [](https://huggingface.co/spaces/fusing/celeba-diffusion) |

## Definitions

**Models**: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to *denoise* a noisy input to an image.
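For reference, the standard DDPM parameterization of this reverse process (Ho et al., 2020, the paper linked above) writes it as a Gaussian whose mean is computed from the model's noise prediction $\epsilon_\theta$:

$$p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) = \mathcal{N}\big(\mathbf{x}_{t-1};\, \mu_\theta(\mathbf{x}_t, t),\, \Sigma_\theta(\mathbf{x}_t, t)\big), \qquad \mu_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(\mathbf{x}_t, t)\right)$$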
@@ -59,77 +149,47 @@ The class provides functionality to compute previous image according to alpha, b

## Philosophy

- Readability and clarity is preferred over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
-- Diffusers is **modality independent** and focusses on providing pretrained models and tools to build systems that generate **continuous outputs**, *e.g.* vision and audio.
-- Diffusion models and schedulers are provided as consise, elementary building blocks whereas diffusion pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box, should stay as close as possible to their original implementation and can include components of other library, such as text-encoders. Examples for diffusion pipelines are [Glide](https://github.com/openai/glide-text2im) and [Latent Diffusion](https://github.com/CompVis/latent-diffusion).
+- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continuous outputs**, *e.g.* vision and audio.
+- Diffusion models and schedulers are provided as concise, elementary building blocks. In contrast, diffusion pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box, should stay as close as possible to their original implementation and can include components of another library, such as text-encoders. Examples for diffusion pipelines are [Glide](https://github.com/openai/glide-text2im) and [Latent Diffusion](https://github.com/CompVis/latent-diffusion).

-## Quickstart
+## Installation

-**Check out this notebook: https://colab.research.google.com/drive/1nMfF04cIxg6FujxsNYi9kiTRrzj4_eZU?usp=sharing**
-
-### Installation
-
-```
-pip install diffusers # should install diffusers 0.0.4
-```
+**With `pip`**
+
+```bash
+pip install --upgrade diffusers # should install diffusers 0.2.1
+```

-### 1. `diffusers` as a toolbox for schedulers and models
+**With `conda`**

-`diffusers` is more modularized than `transformers`. The idea is that researchers and engineers can use only parts of the library easily for the own use cases.
-It could become a central place for all kinds of models, schedulers, training utils and processors that one can mix and match for one's own use case.
-Both models and schedulers should be load- and saveable from the Hub.
-
-For more examples see [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) and [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)
-
-#### **Example for Unconditonal Image generation [DDPM](https://arxiv.org/abs/2006.11239):**
-
-```python
-import torch
-from diffusers import UNet2DModel, DDIMScheduler
-import PIL.Image
-import numpy as np
-import tqdm
-
-torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# 1. Load models
-scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq", tensor_format="pt")
-unet = UNet2DModel.from_pretrained("fusing/ddpm-celeba-hq", ddpm=True).to(torch_device)
-
-# 2. Sample gaussian noise
-generator = torch.manual_seed(23)
-unet.image_size = unet.resolution
-image = torch.randn(
-    (1, unet.in_channels, unet.image_size, unet.image_size),
-    generator=generator,
-)
-image = image.to(torch_device)
-
-# 3. Denoise
-num_inference_steps = 50
-eta = 0.0  # <- deterministic sampling
-scheduler.set_timesteps(num_inference_steps)
-
-for t in tqdm.tqdm(scheduler.timesteps):
-    # 1. predict noise residual
-    with torch.no_grad():
-        residual = unet(image, t)["sample"]
-
-    prev_image = scheduler.step(residual, t, image, eta)["prev_sample"]
-
-    # 3. set current image to prev_image: x_t -> x_t-1
-    image = prev_image
-
-# 4. process image to PIL
-image_processed = image.cpu().permute(0, 2, 3, 1)
-image_processed = (image_processed + 1.0) * 127.5
-image_processed = image_processed.numpy().astype(np.uint8)
-image_pil = PIL.Image.fromarray(image_processed[0])
-
-# 5. save image
-image_pil.save("generated_image.png")
-```
-
-#### **Example for Unconditonal Image generation [LDM](https://github.com/CompVis/latent-diffusion):**
-
-```python
+```sh
+conda install -c conda-forge diffusers
+```

## In the works

For the first release, 🤗 Diffusers focuses on text-to-image diffusion techniques. However, diffusers can be used for much more than that! Over the upcoming releases, we'll be focusing on:

- Diffusers for audio
- Diffusers for reinforcement learning (initial work happening in https://github.com/huggingface/diffusers/pull/105).
- Diffusers for video generation
- Diffusers for molecule generation (initial work happening in https://github.com/huggingface/diffusers/pull/54)

A few pipeline components are already being worked on, namely:

- BDDMPipeline for spectrogram-to-sound vocoding
- GLIDEPipeline to support OpenAI's GLIDE model
- Grad-TTS for text to audio generation / conditional audio generation

We want diffusers to be a toolbox useful for diffusers models in general; if you find yourself limited in any way by the current API, or would like to see additional models, schedulers, or techniques, please open a [GitHub issue](https://github.com/huggingface/diffusers/issues) mentioning what you would like to see.

## Credits

This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today:

- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)
- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)
- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim).
- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)

We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
@@ -30,4 +30,4 @@ with a `set_format(...)` method.

- The ['DDPMScheduler'] was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py).
  An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- The ['DDIMScheduler'] was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
-- The ['PNMDScheduler'] was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
+- The ['PNDMScheduler'] was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
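All three schedulers share the `set_timesteps(...)`/`step(...)` interface, so they can be swapped inside the same denoising loop. A minimal sketch, assuming the 0.2.x-era API shown in the README above, where `step(...)` returns a dict with a `"prev_sample"` key; the dummy tensors stand in for a real model:

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule="linear", tensor_format="pt")
scheduler.set_timesteps(50)  # run 50 inference steps instead of the full 1000

sample = torch.randn(1, 3, 64, 64)  # stand-in for the initial noise
for t in scheduler.timesteps:
    residual = torch.zeros_like(sample)  # stand-in for unet(sample, t)["sample"]
    sample = scheduler.step(residual, t, sample, eta=0.0)["prev_sample"]
```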
@@ -1,12 +1,28 @@
## Training examples

Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install diffusers[training] accelerate datasets
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

### Unconditional Flowers

The command to train a DDPM UNet model on the Oxford Flowers dataset:

```bash
accelerate launch train_unconditional.py \
-  --dataset="huggan/flowers-102-categories" \
+  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=16 \
@@ -17,10 +33,11 @@ accelerate launch train_unconditional.py \
  --mixed_precision=no \
  --push_to_hub
```
An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64

A full training run takes 2 hours on 4xV100 GPUs.

<img src="https://user-images.githubusercontent.com/26864830/173855866-5628989f-856b-4725-a944-d6c09490b2df.png" width="500" />
<img src="https://user-images.githubusercontent.com/26864830/180248660-a0b143d0-b89a-42c5-8656-2ebf6ece7e52.png" width="700" />

### Unconditional Pokemon

@@ -29,7 +46,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:

```bash
accelerate launch train_unconditional.py \
-  --dataset="huggan/pokemon" \
+  --dataset_name="huggan/pokemon" \
  --resolution=64 \
  --output_dir="ddpm-ema-pokemon-64" \
  --train_batch_size=16 \
@@ -40,7 +57,73 @@ accelerate launch train_unconditional.py \
  --mixed_precision=no \
  --push_to_hub
```
An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64

A full training run takes 2 hours on 4xV100 GPUs.

<img src="https://user-images.githubusercontent.com/26864830/173856733-4f117f8c-97bd-4f51-8002-56b488c96df9.png" width="500" />
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />

### Using your own data

To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.

Below, we explain both in more detail.

#### Provide the dataset as a folder

If you provide your own folder with images, the script expects the following directory structure:

```bash
data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png
```

In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:

```bash
accelerate launch train_unconditional.py \
    --train_data_dir <path-to-train-directory> \
    <other-arguments>
```

Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature, which will automatically turn the folders into 🤗 Dataset objects.

#### Upload your data to the hub, as a (possibly private) repo

It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:

```python
from datasets import load_dataset

# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")

# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")

# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")

# example 4: providing several splits
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
```

`ImageFolder` will create an `image` column containing the PIL-encoded images.

Next, push it to the hub!

```python
# assuming you have run the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")

# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```

and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.

More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
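Once pushed, the dataset resolves exactly like any other hub dataset. A quick sketch of loading it back (the repo id "your-username/name_of_your_dataset" is a placeholder, and `use_auth_token` is only needed for private repos):

```python
from datasets import load_dataset

dataset = load_dataset("your-username/name_of_your_dataset", split="train", use_auth_token=True)
print(dataset[0]["image"])  # the PIL image column created by ImageFolder
```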
Deleted file (a GLIDE text-to-image training script):

@@ -1,201 +0,0 @@
import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps}")

    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                        # Learn the variance using the variational bound, but don't let
                        # it affect our mean prediction.
                        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)

                        # predict the noise residual
                        loss = F.mse_loss(model_output, noise_samples)

                        loss = loss / args.gradient_accumulation_steps

                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                    model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                    # Learn the variance using the variational bound, but don't let
                    # it affect our mean prediction.
                    frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)

                    # predict the noise residual
                    loss = F.mse_loss(model_output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

            # process image to PIL
            image_processed = image.squeeze(0)
            image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            image_pil = PIL.Image.fromarray(image_processed)

            # save image
            test_dir = os.path.join(args.output_dir, "test_samples")
            os.makedirs(test_dir, exist_ok=True)
            image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
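The deleted training scripts here add noise via `noise_scheduler.training_step`, which in this era of the library applies the standard forward-diffusion identity (in the usual DDPM notation, with $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$):

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$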
Deleted file (a latent-diffusion text-to-image training script):

@@ -1,216 +0,0 @@
import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
    pipeline.unet = None  # this model will be trained from scratch now
    model = UNetLDMModel(
        attention_resolutions=[4, 2, 1],
        channel_mult=[1, 2, 4, 4],
        context_dim=1280,
        conv_resample=True,
        dims=2,
        dropout=0,
        image_size=8,
        in_channels=4,
        model_channels=320,
        num_heads=8,
        num_res_blocks=2,
        out_channels=4,
        resblock_updown=False,
        transformer_depth=1,
        use_new_attention_order=False,
        use_scale_shift_norm=False,
        use_spatial_transformer=True,
        legacy=False,
    )
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.bert.eval()
    vqvae = pipeline.vqvae.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
            images = 1 / 0.18215 * torch.stack(images, dim=0)
            latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
        return {"images": images, "text_embeddings": text_embeddings, "latents": latents}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
    )
    text_encoder = text_encoder.cpu()
    vqvae = vqvae.cpu()

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps}")

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_latents = batch["latents"]
                noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
                bsz = clean_latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()

                # add noise onto the clean latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                        # predict the noise residual
                        loss = F.mse_loss(output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                    # predict the noise residual
                    loss = F.mse_loss(output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
                global_step += 1

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline(
                    ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50
                )

            # process image to PIL
            image_processed = image.cpu().permute(0, 2, 3, 1)
            image_processed = image_processed * 255.0
            image_processed = image_processed.type(torch.uint8).numpy()
            image_pil = PIL.Image.fromarray(image_processed[0])

            # save image
            test_dir = os.path.join(args.output_dir, "test_samples")
            os.makedirs(test_dir, exist_ok=True)
            image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="ldm-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
train_unconditional.py:

@@ -7,7 +7,7 @@ import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel

@@ -34,27 +34,27 @@ def main(args):
        logging_dir=logging_dir,
    )

-    model = UNetUnconditionalModel(
-        image_size=args.resolution,
+    model = UNet2DModel(
+        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
-        num_res_blocks=2,
-        block_channels=(128, 128, 256, 256, 512, 512),
-        down_blocks=(
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResDownBlock2D",
-            "UNetResAttnDownBlock2D",
-            "UNetResDownBlock2D",
+        layers_per_block=2,
+        block_out_channels=(128, 128, 256, 256, 512, 512),
+        down_block_types=(
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "AttnDownBlock2D",
+            "DownBlock2D",
        ),
-        up_blocks=(
-            "UNetResUpBlock2D",
-            "UNetResAttnUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
-            "UNetResUpBlock2D",
+        up_block_types=(
+            "UpBlock2D",
+            "AttnUpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
        ),
    )
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")

@@ -75,7 +75,17 @@ def main(args):
            Normalize([0.5], [0.5]),
        ]
    )
-    dataset = load_dataset(args.dataset, split="train")
+
+    if args.dataset_name is not None:
+        dataset = load_dataset(
+            args.dataset_name,
+            args.dataset_config_name,
+            cache_dir=args.cache_dir,
+            use_auth_token=True if args.use_auth_token else None,
+            split="train",
+        )
+    else:
+        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]

@@ -147,9 +157,9 @@ def main(args):

    accelerator.wait_for_everyone()

-    # Generate a sample image for visual inspection
+    # Generate sample images for visual inspection
    if accelerator.is_main_process:
        with torch.no_grad():
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,

@@ -157,13 +167,13 @@ def main(args):

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]

-                # denormalize the images and save to tensorboard
-                images_processed = (images.cpu() + 1.0) * 127.5
-                images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()
-
-                accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
+                # denormalize the images and save to tensorboard
+                images_processed = (images * 255).round().astype("uint8")
+                accelerator.trackers[0].writer.add_images(
+                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
+                )

            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model

@@ -179,14 +189,18 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
+    parser.add_argument("--dataset_name", type=str, default=None)
+    parser.add_argument("--dataset_config_name", type=str, default=None)
+    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
+    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_model_epochs", type=int, default=5)
    parser.add_argument("--save_images_epochs", type=int, default=10)
+    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")

@@ -194,12 +208,13 @@ if __name__ == "__main__":
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-3)
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
+    parser.add_argument("--use_auth_token", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
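The three `ema_*` arguments parameterize the EMA warm-up schedule used by `EMAModel`. To our reading (worth verifying against `diffusers.training_utils`), the decay at optimization step $s$ follows the usual inverse-gamma warm-up:

$$\text{decay}(s) = \min\left(\text{ema\_max\_decay},\ 1 - \left(1 + \frac{s}{\text{ema\_inv\_gamma}}\right)^{-\text{ema\_power}}\right)$$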
@@ -221,4 +236,7 @@ if __name__ == "__main__":
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
    main(args)
Checkpoint conversion script (convert_ddpm_checkpoint / convert_vq_autoenc_checkpoint):

@@ -1,4 +1,4 @@
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
+from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
import argparse
import json
import torch

@@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
    target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

-    num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+    num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3

    old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
    query, key, value = old_tensor.split(channels // num_heads, dim=1)
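The reshape/split above is the standard trick for unfusing a packed QKV projection into separate query/key/value weights. A standalone sketch of the same arithmetic (the shapes are illustrative, not taken from any particular checkpoint):

```python
import torch

num_head_channels = 16
channels = 64

fused = torch.randn(3 * channels, channels, 1, 1)  # packed qkv conv weight, stacked along dim 0

num_heads = fused.shape[0] // num_head_channels // 3  # -> 4, as in the script
grouped = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
query, key, value = grouped.split(channels // num_heads, dim=1)

# each piece flattens back to a (channels, channels, 1, 1) projection
print(query.reshape(channels, channels, 1, 1).shape)
```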
@@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
    if attention_paths_to_split is not None and new_path in attention_paths_to_split:
        continue

-    new_path = new_path.replace('down.', 'downsample_blocks.')
+    new_path = new_path.replace('down.', 'down_blocks.')
    new_path = new_path.replace('up.', 'up_blocks.')

    if additional_replacements is not None:

@@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
    new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
    new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

-    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
+    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
+    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

    num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

-    for i in range(num_downsample_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
+    for i in range(num_down_blocks):
+        block_id = (i - 1) // (config['layers_per_block'] + 1)

-        if any('downsample' in layer for layer in downsample_blocks[i]):
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
-            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+        if any('downsample' in layer for layer in down_blocks[i]):
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
+            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+            # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in downsample_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('block' in layer for layer in down_blocks[i]):
+            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
+            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

            if num_blocks > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in downsample_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any('attn' in layer for layer in down_blocks[i]):
+            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
+            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

            if num_attn > 0:
-                for j in range(config['num_res_blocks']):
+                for j in range(config['layers_per_block']):
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

@@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
        blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

        if num_blocks > 0:
-            for j in range(config['num_res_blocks'] + 1):
+            for j in range(config['layers_per_block'] + 1):
                replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                paths = renew_resnet_paths(blocks[j])
                assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

@@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
        attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

        if num_attn > 0:
-            for j in range(config['num_res_blocks'] + 1):
+            for j in range(config['layers_per_block'] + 1):
                replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
                paths = renew_attention_paths(attns[j])
                assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

@@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
    return new_checkpoint


def convert_vq_autoenc_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']

    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']

    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']

    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']

    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

    for i in range(num_down_blocks):
        block_id = (i - 1) // (config['layers_per_block'] + 1)

        if any('downsample' in layer for layer in down_blocks[i]):
            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']

        if any('block' in layer for layer in down_blocks[i]):
            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

            if num_blocks > 0:
                for j in range(config['layers_per_block']):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

        if any('attn' in layer for layer in down_blocks[i]):
            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block']):
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
||||
|
||||
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
|
||||
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
|
||||
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
|
||||
|
||||
# Mid new 2
|
||||
paths = renew_resnet_paths(mid_block_1_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
|
||||
])
|
||||
|
||||
paths = renew_resnet_paths(mid_block_2_layers)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
|
||||
])
|
||||
|
||||
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
|
||||
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
|
||||
])
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
|
||||
if any('upsample' in layer for layer in up_blocks[i]):
|
||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
|
||||
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
|
||||
|
||||
if any('block' in layer for layer in up_blocks[i]):
|
||||
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
|
||||
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
|
||||
if num_blocks > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
paths = renew_resnet_paths(blocks[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
if any('attn' in layer for layer in up_blocks[i]):
|
||||
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
|
||||
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
|
||||
|
||||
if num_attn > 0:
|
||||
for j in range(config['layers_per_block'] + 1):
|
||||
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
|
||||
paths = renew_attention_paths(attns[j])
|
||||
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
||||
|
||||
new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
|
||||
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
|
||||
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
|
||||
if "quantize.embedding.weight" in checkpoint:
|
||||
new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
|
||||
new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
|
||||
new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -220,15 +331,29 @@ if __name__ == "__main__":
|
||||
with open(args.config_file) as f:
|
||||
config = json.loads(f.read())
|
||||
|
||||
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
|
||||
# unet case
|
||||
key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
|
||||
if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
|
||||
converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
|
||||
else:
|
||||
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
|
||||
|
||||
if "ddpm" in config:
|
||||
del config["ddpm"]
|
||||
|
||||
model = UNet2DModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
if config["_class_name"] == "VQModel":
|
||||
model = VQModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
model.save_pretrained(args.dump_path)
|
||||
elif config["_class_name"] == "AutoencoderKL":
|
||||
model = AutoencoderKL(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
model.save_pretrained(args.dump_path)
|
||||
else:
|
||||
model = UNet2DModel(**config)
|
||||
model.load_state_dict(converted_checkpoint)
|
||||
|
||||
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
||||
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
||||
|
||||
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
|
||||
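For orientation, a minimal sketch of how the converted artifacts might be loaded back, assuming the script above was run with a hypothetical `--dump_path ./converted`; the path is a placeholder, not a real checkpoint location:

```python
# Hypothetical smoke test of a converted checkpoint; "./converted" is a
# made-up dump path, not an official model identifier.
from diffusers import DDPMPipeline, VQModel

# UNet2DModel case: the script saved a full DDPM pipeline
pipe = DDPMPipeline.from_pretrained("./converted")

# VQModel / AutoencoderKL case: the script saved the model directly
vq = VQModel.from_pretrained("./converted")
```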
23
setup.py
@@ -77,18 +77,22 @@ from setuptools import find_packages, setup
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow",
"accelerate>=0.11.0",
"black~=22.0,>=22.3",
"datasets",
"filelock",
"flake8>=3.8.3",
"huggingface-hub",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1,<1.0",
"importlib_metadata",
"isort>=5.5.4",
"modelcards==0.1.4",
"numpy",
"pytest",
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"tensorboard",
"modelcards==0.1.4"
"torch>=1.4",
]

# this is a lookup table with items like:
@@ -160,14 +164,13 @@ extras = {}

extras = {}
extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras["docs"] = []
extras["training"] = ["tensorboard", "modelcards"]
extras["test"] = [
"pytest",
]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["pytest"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]

install_requires = [
deps["importlib_metadata"],
deps["filelock"],
deps["huggingface-hub"],
deps["numpy"],
@@ -179,7 +182,7 @@ install_requires = [

setup(
name="diffusers",
version="0.1.0",
version="0.2.2",
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@@ -1,19 +1,43 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available


__version__ = "0.1.0"
__version__ = "0.2.2"

from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)


if is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
from .utils.dummy_scipy_objects import *

from .training_utils import EMAModel


if is_transformers_available():
from .pipelines import LDMTextToImagePipeline
from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline
else:
from .utils.dummy_transformers_objects import *
@@ -23,17 +23,11 @@ from collections import OrderedDict
from typing import Any, Dict, Tuple, Union

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from . import __version__
from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
logging,
)
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging


logger = logging.get_logger(__name__)
@@ -3,16 +3,20 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black~=22.0,>=22.3",
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"huggingface-hub": "huggingface-hub",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"modelcards": "modelcards==0.1.4",
"numpy": "numpy",
"pytest": "pytest",
"regex": "regex!=2019.12.17",
"requests": "requests",
"torch": "torch>=1.4",
"tensorboard": "tensorboard",
"modelcards": "modelcards==0.1.4",
"torch": "torch>=1.4",
}
@@ -21,14 +21,13 @@ from typing import Optional

from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami
from utils import is_modelcards_available

from .utils import is_modelcards_available, logging


if is_modelcards_available():
from modelcards import CardData, ModelCard

from .utils import logging


logger = logging.get_logger(__name__)

@@ -169,13 +168,13 @@ def create_model_card(args, model_name):
license="apache-2.0",
library_name="diffusers",
tags=[],
datasets=args.dataset,
datasets=args.dataset_name,
metrics=[],
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_name=model_name,
repo_name=repo_name,
dataset_name=args.dataset if hasattr(args, "dataset") else None,
dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
learning_rate=args.learning_rate,
train_batch_size=args.train_batch_size,
eval_batch_size=args.eval_batch_size,
@@ -185,7 +184,7 @@ def create_model_card(args, model_name):
adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_weight_decay") else None,
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
@@ -21,17 +21,10 @@ import torch
from torch import Tensor, device

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
logging,
)
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging


WEIGHTS_NAME = "diffusion_pytorch_model.bin"
@@ -32,10 +32,10 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
exponent = exponent / (half_dim - downscale_freq_shift)

emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
emb = torch.exp(emb * emb_coeff)
emb = torch.exp(exponent).to(device=timesteps.device)
emb = timesteps[:, None].float() * emb[None, :]

# scale embeddings
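The hunk above replaces the fused `emb_coeff` formulation with an equivalent two-step `exponent` one. A self-contained sketch of the sinusoidal embedding being computed, with default argument values assumed from the visible lines:

```python
import math

import torch


def sinusoidal_timestep_embedding(timesteps, embedding_dim, max_period=10000, downscale_freq_shift=1.0):
    # frequencies follow exp(-log(max_period) * i / (half_dim - shift)), as in the refactor above
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent).to(device=timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]
    # sin/cos halves together give the full embedding
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


print(sinusoidal_timestep_embedding(torch.arange(4), 8).shape)  # torch.Size([4, 8])
```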
@@ -288,7 +288,10 @@ class ResnetBlock(nn.Module):

self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None

self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
@@ -328,7 +331,9 @@ class ResnetBlock(nn.Module):
def forward(self, x, temb, hey=False):
h = x

h = self.norm1(h)
# make sure hidden states is in float32
# when running in half-precision
h = self.norm1(h.float()).type(h.dtype)
h = self.nonlinearity(h)

if self.upsample is not None:
@@ -344,7 +349,9 @@ class ResnetBlock(nn.Module):
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb

h = self.norm2(h)
# make sure hidden states is in float32
# when running in half-precision
h = self.norm2(h.float()).type(h.dtype)
h = self.nonlinearity(h)

h = self.dropout(h)
@@ -364,8 +371,9 @@ class ResnetBlock(nn.Module):
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data

self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
if self.time_emb_proj is not None:
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
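The recurring `norm(h.float()).type(h.dtype)` pattern in these hunks is the standard half-precision trick: normalization statistics are computed in float32 and the result is cast back. A minimal sketch of the idea in isolation:

```python
import torch

norm = torch.nn.GroupNorm(num_groups=8, num_channels=32)
h = torch.randn(2, 32, 16, 16).half()

# normalize in float32 for numerical stability, then return to fp16
out = norm(h.float()).type(h.dtype)
print(out.dtype)  # torch.float16
```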
@@ -132,6 +132,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)

@@ -166,7 +169,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)

# 6. post-process
sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
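The `broadcast_to` addition lets callers pass a single scalar timestep for a whole batch. Roughly, with illustrative shapes:

```python
import torch

sample = torch.randn(4, 3, 32, 32)   # batch of 4
timesteps = torch.tensor(999)        # 0-d scalar timestep

if torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast the single timestep across the batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])
print(timesteps.shape)  # torch.Size([4])
```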
@@ -28,6 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
cross_attention_dim=1280,
attention_head_dim=8,
):
super().__init__()
@@ -64,6 +65,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
@@ -77,6 +79,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
@@ -101,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
@@ -129,6 +133,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)

@@ -168,8 +175,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

# 6. post-process

sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
cross_attention_dim=None,
downsample_padding=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
@@ -58,6 +59,8 @@ def get_down_block(
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
@@ -67,6 +70,7 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "SkipDownBlock2D":
@@ -92,6 +96,16 @@ def get_down_block(
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
elif down_block_type == "DownEncoderBlock2D":
return DownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
)


def get_up_block(
@@ -105,6 +119,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -119,6 +134,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
@@ -128,6 +145,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "AttnUpBlock2D":
@@ -165,6 +183,15 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "UpDecoderBlock2D":
return UpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
raise ValueError(f"{up_block_type} does not exist.")


@@ -553,6 +580,139 @@ class DownBlock2D(nn.Module):
return hidden_states, output_states


class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)

self.resnets = nn.ModuleList(resnets)

if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None

def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

return hidden_states
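A hypothetical smoke test of the new block; the channel sizes are illustrative, not taken from any released config:

```python
import torch

# assumes DownEncoderBlock2D from the hunk above is importable
block = DownEncoderBlock2D(in_channels=32, out_channels=64, num_layers=2)

x = torch.randn(1, 32, 16, 16)
y = block(x)    # no timestep embedding: the resnets run with temb=None
print(y.shape)  # spatial size halved by the trailing Downsample2D, e.g. (1, 64, 8, 8)
```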

class AttnDownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
attentions = []

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None

def forward(self, hidden_states):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)

return hidden_states


class AttnSkipDownBlock2D(nn.Module):
def __init__(
self,
@@ -946,6 +1106,127 @@ class UpBlock2D(nn.Module):
return hidden_states


class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []

for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels

resnets.append(
ResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)

self.resnets = nn.ModuleList(resnets)

if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None

def forward(self, hidden_states):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)

return hidden_states


class AttnUpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []

for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels

resnets.append(
ResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)

self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None

def forward(self, hidden_states):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)

if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)

return hidden_states


class AttnSkipUpBlock2D(nn.Module):
def __init__(
self,
@@ -4,221 +4,165 @@ import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


def nonlinearity(x):
# swish
return x * torch.sigmoid(x)


def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block


class Encoder(nn.Module):
def __init__(
self,
*,
ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
act_fn="silu",
double_z=True,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.layers_per_block = layers_per_block

# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
self.mid_block = None
self.down_blocks = nn.ModuleList([])

# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1

down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
attn_num_head_channels=None,
temb_channels=None,
)
self.down_blocks.append(down_block)

# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
temb_channels=None,
)

# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
)
# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
self.conv_act = nn.SiLU()

conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

def forward(self, x):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
sample = x
sample = self.conv_in(sample)

# timestep embedding
temb = None

# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# down
for down_block in self.down_blocks:
sample = down_block(sample)

# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
sample = self.mid_block(sample)

# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

return sample
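An illustrative round trip through the refactored `Encoder`; the configuration below is a guess for demonstration, not one shipped with any checkpoint:

```python
import torch

# hypothetical two-level encoder; with double_z=True the head emits
# mean and log-variance, hence 2 * out_channels feature maps
enc = Encoder(
    in_channels=3,
    out_channels=4,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    double_z=True,
)
z = enc(torch.randn(1, 3, 32, 32))
print(z.shape)  # (1, 8, 16, 16): channels doubled, resolution halved once
```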


class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
**ignorekwargs,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
act_fn="silu",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.layers_per_block = layers_per_block

# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.up_blocks = nn.ModuleList([])

# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
temb_channels=None,
)

# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up)  # prepend to get consistent order
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]

# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
is_final_block = i == len(block_out_channels) - 1

up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attn_num_head_channels=None,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape

# timestep embedding
temb = None

# z to block_in
h = self.conv_in(z)
sample = z
sample = self.conv_in(sample)

# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
sample = self.mid_block(sample)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# up
for up_block in self.up_blocks:
sample = up_block(sample)

# end
if self.give_pre_end:
return h
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
return sample


class VectorQuantizer(nn.Module):
@@ -383,57 +327,44 @@ class VQModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
ch,
out_ch,
num_res_blocks,
attn_resolutions,
in_channels,
resolution,
z_channels,
n_embed,
embed_dim,
remap=None,
sane_index_shape=False,  # tell vector quantizer to return indices as bhw
ch_mult=(1, 2, 4, 8),
dropout=0.0,
double_z=True,
resamp_with_conv=True,
give_pre_end=False,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=3,
sample_size=32,
num_vq_embeddings=256,
):
super().__init__()

# pass init params to Encoder
self.encoder = Encoder(
ch=ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
double_z=double_z,
give_pre_end=give_pre_end,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=False,
)

self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.quantize = VectorQuantizer(
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

# pass init params to Decoder
self.decoder = Decoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)

def encode(self, x):
@@ -462,57 +393,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
ch,
out_ch,
num_res_blocks,
attn_resolutions,
in_channels,
resolution,
z_channels,
embed_dim,
remap=None,
sane_index_shape=False,  # tell vector quantizer to return indices as bhw
ch_mult=(1, 2, 4, 8),
dropout=0.0,
double_z=True,
resamp_with_conv=True,
give_pre_end=False,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
sample_size=32,
):
super().__init__()

# pass init params to Encoder
self.encoder = Encoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
double_z=double_z,
give_pre_end=give_pre_end,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=True,
)

# pass init params to Decoder
self.decoder = Decoder(
ch=ch,
out_ch=out_ch,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
ch_mult=ch_mult,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
give_pre_end=give_pre_end,
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)

self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)

def encode(self, x):
h = self.encoder(x)
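Worth noting why `quant_conv` maps `2 * latent_channels` to `2 * latent_channels` here: with `double_z=True` the encoder emits mean and log-variance, which downstream code splits into a diagonal Gaussian. A rough sketch of that sampling step with illustrative names and shapes:

```python
import torch

latent_channels = 4
moments = torch.randn(1, 2 * latent_channels, 8, 8)  # encoder + quant_conv output

# split into mean and log-variance, then sample via the reparameterization trick
mean, logvar = torch.chunk(moments, 2, dim=1)
latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
print(latents.shape)  # (1, 4, 8, 8)
```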
@@ -18,7 +18,6 @@ import math
from enum import Enum
from typing import Optional, Union

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
@@ -15,6 +15,7 @@
# limitations under the License.

import importlib
import inspect
import os
from typing import Optional, Union

@@ -148,6 +149,12 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

init_kwargs = {}
@@ -158,8 +165,36 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None

# if the model is in a pipeline module, then we load it from the pipeline
if is_pipeline_module:
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate

if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)

# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
@@ -171,23 +206,24 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]

load_method = getattr(class_obj, load_method_name)
load_method = getattr(class_obj, load_method_name)

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name))
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder)
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name))
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder)

init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

# 5. Instantiate the pipeline
# 4. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model
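The new `passed_class_obj` logic means already-instantiated components can be handed straight to `from_pretrained` and skip loading from disk. A hypothetical call (the repo id below is a placeholder, not a real model):

```python
from diffusers import DDIMScheduler, DiffusionPipeline

# swap in a custom scheduler instead of the one stored with the pipeline;
# "some-org/some-ddim-repo" is a made-up identifier
scheduler = DDIMScheduler()
pipe = DiffusionPipeline.from_pretrained("some-org/some-ddim-repo", scheduler=scheduler)
```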
@@ -10,10 +10,10 @@ TODO(Patrick, Anton, Suraj)

## Examples

- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/glide/pipeline_glide.py).
- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/bddm/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/grad_tts/pipeline_grad_tts.py).
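As a rough usage illustration for the pipelines listed above (a minimal sketch; the hub model id "google/ddpm-cifar10-32" is an assumed example, not something this diff pins down):

from diffusers import DDPMPipeline

# download a trained unconditional DDPM together with its scheduler (assumed hub id)
ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

# run the full denoising loop; pipelines in this version return a dict with a "sample" key
image = ddpm()["sample"][0]
image.save("ddpm_generated_image.png")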
@@ -1,10 +1,16 @@
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ..utils import is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochatic_karras_ve import KarrasVePipeline


if is_transformers_available():
    from .latent_diffusion import LDMTextToImagePipeline
    from .stable_diffusion import StableDiffusionPipeline
@@ -1 +1,2 @@
# flake8: noqa
from .pipeline_ddim import DDIMPipeline

@@ -1 +1,2 @@
# flake8: noqa
from .pipeline_ddpm import DDPMPipeline

@@ -1,3 +1,4 @@
# flake8: noqa
from ...utils import is_transformers_available
@@ -1,3 +1,4 @@
import inspect
from typing import Optional, Tuple, Union

import torch
@@ -45,11 +46,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
# get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0:
    uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
    uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
    uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]

# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]

latents = torch.randn(
    (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
@@ -59,6 +60,13 @@ class LDMTextToImagePipeline(DiffusionPipeline):

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

extra_kwargs = {}
if accepts_eta:
    extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
    if guidance_scale == 1.0:
        # guidance_scale of 1 means no guidance
@@ -79,7 +87,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
@@ -618,5 +626,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
sequence_output = outputs[0]
return sequence_output
return outputs
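The `accepts_eta` lines above are plain runtime signature inspection: build the kwarg dict only for parameters the callee actually takes, so one loop body works for both DDIM-style (eta) and DDPM-style schedulers. A self-contained sketch of the pattern (the `step` function here is a stand-in, not the real scheduler API):

import inspect

def step(sample, t, eta=0.0):  # stand-in for scheduler.step
    return sample

extra_kwargs = {}
if "eta" in inspect.signature(step).parameters:
    extra_kwargs["eta"] = 0.0  # only forwarded when the signature accepts it

step(None, 0, **extra_kwargs)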
@@ -1 +1,2 @@
# flake8: noqa
from .pipeline_latent_diffusion_uncond import LDMPipeline
@@ -1,3 +1,5 @@
import inspect

import torch

from tqdm.auto import tqdm
@@ -31,11 +33,18 @@ class LDMPipeline(DiffusionPipeline):

self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

extra_kwargs = {}
if accepts_eta:
    extra_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
    # predict the noise residual
    noise_prediction = self.unet(latents, t)["sample"]
    # compute the previous noisy sample x_t -> x_t-1
    latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
    latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]

# decode the image latents with the VAE
image = self.vqvae.decode(latents)
@@ -1 +1,2 @@
# flake8: noqa
from .pipeline_pndm import PNDMPipeline

@@ -1 +1,2 @@
# flake8: noqa
from .pipeline_score_sde_ve import ScoreSdeVePipeline
@@ -6,31 +6,33 @@ from tqdm.auto import tqdm


class ScoreSdeVePipeline(DiffusionPipeline):
    def __init__(self, model, scheduler):
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):

        img_size = self.model.config.sample_size
        shape = (1, 3, img_size, img_size)
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        model = self.model.to(device)
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet.to(torch_device)

        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
        sample = sample.to(device)
        sample = sample.to(torch_device)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=torch_device)

            # correction step
            for _ in range(self.scheduler.correct_steps):
                model_output = self.model(sample, sigma_t)["sample"]
                model_output = self.unet(sample, sigma_t)["sample"]
                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]

            # prediction step
@@ -39,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):

        sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

        sample = sample.clamp(0, 1)
        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)
src/diffusers/pipelines/stable_diffusion/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# flake8: noqa
from ...utils import is_transformers_available


if is_transformers_available():
    from .pipeline_stable_diffusion import StableDiffusionPipeline
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (new file)
@@ -0,0 +1,142 @@
import inspect
from typing import List, Optional, Union

import torch

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        torch_device: Optional[Union[str, torch.device]] = None,
        output_type: Optional[str] = "pil",
    ):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        self.unet.to(torch_device)
        self.vae.to(torch_device)
        self.text_encoder.to(torch_device)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=torch_device,
        )

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image}
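A minimal usage sketch for the pipeline above (the hub id "CompVis/stable-diffusion-v1-4" is an assumption for illustration; real weights may require accepting a license and authenticating):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # assumed model id

# __call__ is already wrapped in torch.no_grad(); it returns {"sample": [PIL.Image, ...]}
image = pipe("a photograph of an astronaut riding a horse", num_inference_steps=50)["sample"][0]
image.save("astronaut.png")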
src/diffusers/pipelines/stochatic_karras_ve/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# flake8: noqa
from .pipeline_stochastic_karras_ve import KarrasVePipeline
src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py (new file)
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
import torch

from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasVeScheduler


class KarrasVePipeline(DiffusionPipeline):
    """
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
    the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
    differential equations." https://arxiv.org/abs/2011.13456
    """

    unet: UNet2DModel
    scheduler: KarrasVeScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet.to(torch_device)

        # sample x_0 ~ N(0, sigma_0^2 * I)
        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
        sample = sample.to(torch_device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(self.scheduler.timesteps):
            # here sigma_t == t_i from the paper
            sigma = self.scheduler.schedule[t]
            sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

            # 1. Select temporarily increased noise level sigma_hat
            # 2. Add new noise to move from sample_i to sample_hat
            sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)

            # 3. Predict the noise residual given the noise magnitude `sigma_hat`
            # The model inputs and output are adjusted by following eq. (213) in [1].
            model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]

            # 4. Evaluate dx/dt at sigma_hat
            # 5. Take Euler step from sigma to sigma_prev
            step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

            if sigma_prev != 0:
                # 6. Apply 2nd order correction
                # The model inputs and output are adjusted by following eq. (213) in [1].
                model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
                step_output = self.scheduler.step_correct(
                    model_output,
                    sigma_hat,
                    sigma_prev,
                    sample_hat,
                    step_output["prev_sample"],
                    step_output["derivative"],
                )
            sample = step_output["prev_sample"]

        sample = (sample / 2 + 0.5).clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        return {"sample": sample}
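The step/step_correct pair used above is Heun's method: an Euler step followed by a second-order correction that averages the two slope estimates. Reconstructed from the code (the notation is mine, not the diff's):

\begin{aligned}
d_i &= \frac{\hat{x}_i - x_{\text{pred}}}{\hat{\sigma}_i}, \\
x_{i+1}^{\text{Euler}} &= \hat{x}_i + (\sigma_{i+1} - \hat{\sigma}_i)\, d_i, \\
x_{i+1} &= \hat{x}_i + (\sigma_{i+1} - \hat{\sigma}_i)\, \tfrac{1}{2}\,\big(d_i + d_i'\big),
\end{aligned}

where d_i' is the derivative re-evaluated at the Euler estimate, and the correction is skipped when sigma_prev == 0 (the final step).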
@@ -1,18 +1,18 @@
# Schedulers

- Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
- Schedulers can be used interchangable between diffusion models in inference to find the preferred tradef-off between speed and generation quality.
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are available in numpy, but can easily be transformed into PyTorch.

## API

- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass.
- Schedulers should be framework-agonstic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.

## Examples

- The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- The PNMD scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
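Interchangeability in practice: StableDiffusionPipeline's constructor type-hints its scheduler as Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], so swapping schedulers is just passing a different one in. A hedged sketch (vae, text_encoder, tokenizer and unet are assumed to be already-loaded components; the beta values are illustrative, not prescribed by this diff):

from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=lms
)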
@@ -16,9 +16,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_scipy_available
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin


if is_scipy_available():
    from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
    from ..utils.dummy_scipy_objects import *
@@ -59,6 +59,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
    trained_betas=None,
    timestep_values=None,
    clip_sample=True,
    set_alpha_to_one=True,
    tensor_format="pt",
):

@@ -75,7 +76,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# setable values
self.num_inference_steps = None
@@ -86,7 +92,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

def _get_variance(self, timestep, prev_timestep):
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

@@ -94,11 +100,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

    return variance

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, offset=0):
    self.num_inference_steps = num_inference_steps
    self.timesteps = np.arange(
        0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
    )[::-1].copy()
    self.timesteps += offset
    self.set_format(tensor_format=self.tensor_format)

def step(
@@ -126,7 +133,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
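A worked example of the new offset argument: with num_train_timesteps=1000, num_inference_steps=4 and offset=1, set_timesteps produces

import numpy as np

timesteps = np.arange(0, 1000, 1000 // 4)[::-1].copy()  # [750, 500, 250, 0]
timesteps += 1                                          # [751, 501, 251, 1]
print(timesteps)

so every inference timestep is shifted by one, which is what the stable diffusion pipeline requests via extra_set_kwargs["offset"] = 1.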
@@ -65,6 +65,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
    self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
    self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
    # this schedule is very specific to the latent diffusion model.
    self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
    # Glide cosine schedule
    self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -82,6 +85,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

self.variance_type = variance_type

def set_timesteps(self, num_inference_steps):
    num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
    self.num_inference_steps = num_inference_steps
@@ -90,7 +95,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
    )[::-1].copy()
    self.set_format(tensor_format=self.tensor_format)

def _get_variance(self, t, variance_type=None):
def _get_variance(self, t, predicted_variance=None, variance_type=None):
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

@@ -113,6 +118,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif variance_type == "fixed_large_log":
    # Glide max_log
    variance = self.log(self.betas[t])
elif variance_type == "learned":
    return predicted_variance
elif variance_type == "learned_range":
    min_log = variance
    max_log = self.betas[t]
    frac = (predicted_variance + 1) / 2
    variance = frac * max_log + (1 - frac) * min_log

return variance

@@ -125,6 +137,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
    generator=None,
):
    t = timestep

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -155,7 +173,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
    variance = 0
    if t > 0:
        noise = self.randn_like(model_output, generator=generator)
        variance = (self._get_variance(t) ** 0.5) * noise
        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise

    pred_prev_sample = pred_prev_sample + variance
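The "learned_range" branch mirrors the Improved DDPM parameterization: the extra output channel v in [-1, 1] that torch.split peels off the model output is mapped to a fraction f = (v + 1) / 2 and used to interpolate between a minimum and a maximum variance,

\sigma_t^2 = f \cdot \sigma_{\max}^2 + (1 - f)\cdot \sigma_{\min}^2,

with the endpoints held in `min_log`/`max_log` (the names suggest the reference formulation interpolates log-variances; the code as shown interpolates the stored values directly, so this reading is a reconstruction of intent rather than a statement about the diff).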
src/diffusers/schedulers/scheduling_karras_ve.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
    the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
    differential equations." https://arxiv.org/abs/2011.13456
    """

    @register_to_config
    def __init__(
        self,
        sigma_min=0.02,
        sigma_max=100,
        s_noise=1.007,
        s_churn=80,
        s_min=0.05,
        s_max=50,
        tensor_format="pt",
    ):
        """
        For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
        Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
        optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.

        Args:
            sigma_min (`float`): minimum noise magnitude
            sigma_max (`float`): maximum noise magnitude
            s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
                A reasonable range is [1.000, 1.011].
            s_churn (`float`): the parameter controlling the overall amount of stochasticity.
                A reasonable range is [0, 100].
            s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
                A reasonable range is [0, 10].
            s_max (`float`): the end value of the sigma range where we add noise.
                A reasonable range is [0.2, 80].
        """
        # setable values
        self.num_inference_steps = None
        self.timesteps = None
        self.schedule = None  # sigma(t_i)

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps):
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.schedule = [
            (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
            for i in self.timesteps
        ]
        self.schedule = np.array(self.schedule, dtype=np.float32)

        self.set_format(tensor_format=self.tensor_format)

    def add_noise_to_input(self, sample, sigma, generator=None):
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
        """
        if self.s_min <= sigma <= self.s_max:
            gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: Union[torch.FloatTensor, np.ndarray],
    ):
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        return {"prev_sample": sample_prev, "derivative": derivative}

    def step_correct(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: Union[torch.FloatTensor, np.ndarray],
        sample_prev: Union[torch.FloatTensor, np.ndarray],
        derivative: Union[torch.FloatTensor, np.ndarray],
    ):
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
        return {"prev_sample": sample_prev, "derivative": derivative_corr}

    def add_noise(self, original_samples, noise, timesteps):
        raise NotImplementedError()
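The schedule built in set_timesteps decays geometrically. Written out:

\sigma_i = \sigma_{\max}\left(\frac{\sigma_{\min}^{2}}{\sigma_{\max}^{2}}\right)^{i/(N-1)}, \qquad i = 0, \dots, N-1,

so sigma_0 = sigma_max at the noisiest step. Note that because the ratio is squared, the final value is sigma_min^2 / sigma_max rather than sigma_min exactly, which is worth keeping in mind when comparing against the paper's schedule.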
src/diffusers/schedulers/scheduling_lms_discrete.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        tensor_format="pt",
    ):
        """
        Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
        Katherine Crowson:
        https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
        """

        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.derivatives = []

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient
        """

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps):
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

        low_idx = np.floor(self.timesteps).astype(int)
        high_idx = np.ceil(self.timesteps).astype(int)
        frac = np.mod(self.timesteps, 1.0)
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        self.sigmas = np.concatenate([sigmas, [0.0]])

        self.derivatives = []

        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        order: int = 4,
    ):
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        order = min(timestep + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        return {"prev_sample": prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        alpha_prod = self.alphas_cumprod[timesteps]
        alpha_prod = self.match_shape(alpha_prod, original_samples)

        noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
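get_lms_coefficient is the classic Adams-Bashforth construction: each coefficient integrates a Lagrange basis polynomial over one sigma interval,

c_j = \int_{\sigma_t}^{\sigma_{t+1}} \prod_{k \neq j} \frac{\tau - \sigma_{t-k}}{\sigma_{t-j} - \sigma_{t-k}} \, d\tau,

and step then advances the probability-flow ODE with prev_sample = sample + \sum_j c_j \, d_{t-j}, where the d terms are the stored derivatives (notation reconstructed from the code above, not taken from the diff).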
@@ -56,10 +56,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    beta_end=0.02,
    beta_schedule="linear",
    tensor_format="pt",
    skip_prk_steps=False,
):

    if beta_schedule == "linear":
        self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -85,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
@@ -92,19 +97,30 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, offset=0):
    self.num_inference_steps = num_inference_steps
    self._timesteps = list(
        range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
    )
    self._offset = offset
    self._timesteps = [t + self._offset for t in self._timesteps]

    if self.config.skip_prk_steps:
        # for some models like stable diffusion the prk steps can/should be skipped to
        # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
        # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
        self.prk_timesteps = []
        self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
    else:
        prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
            np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
        )
        self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
        self.plms_timesteps = list(reversed(self._timesteps[:-3]))

    prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
        np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
    )
    self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
    self.plms_timesteps = list(reversed(self._timesteps[:-3]))
    self.timesteps = self.prk_timesteps + self.plms_timesteps

    self.ets = []
    self.counter = 0
    self.set_format(tensor_format=self.tensor_format)

@@ -114,7 +130,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
):
    if self.counter < len(self.prk_timesteps):
    if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
        return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
    else:
        return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
@@ -163,7 +179,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
    times to approximate the solution.
    """
    if len(self.ets) < 3:
    if not self.config.skip_prk_steps and len(self.ets) < 3:
        raise ValueError(
            f"{self.__class__} can only be run AFTER scheduler has been run "
            "in 'prk' mode for at least 12 iterations "
@@ -172,9 +188,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        )

    prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
    self.ets.append(model_output)

    model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
    if self.counter != 1:
        self.ets.append(model_output)
    else:
        prev_timestep = timestep
        timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

    if len(self.ets) == 1 and self.counter == 0:
        model_output = model_output
        self.cur_sample = sample
    elif len(self.ets) == 1 and self.counter == 1:
        model_output = (model_output + self.ets[-1]) / 2
        sample = self.cur_sample
        self.cur_sample = None
    elif len(self.ets) == 2:
        model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
    elif len(self.ets) == 3:
        model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
    else:
        model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
    self.counter += 1
@@ -194,8 +227,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    # sample -> x_t
    # model_output -> e_θ(x_t, t)
    # prev_sample -> x_(t−δ)
    alpha_prod_t = self.alphas_cumprod[timestep + 1]
    alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
    alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
    alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
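A worked example of the skip_prk_steps branch above: with num_train_timesteps=1000 and num_inference_steps=5,

timesteps = [0, 200, 400, 600, 800]
plms = list(reversed(timesteps[:-1] + timesteps[-2:-1] + timesteps[-1:]))
print(plms)  # [800, 600, 600, 400, 200, 0]

The duplicated 600 is deliberate: it is the warm-up step that the `self.counter == 1` special case in step_plms consumes in place of the skipped Runge-Kutta iterations.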
@@ -1,8 +1,44 @@
import copy
import os
import random

import numpy as np
import torch


def enable_full_determinism(seed: int):
    """
    Helper function for reproducible behavior during distributed training. See
    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
    """
    # set seed first
    set_seed(seed)

    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


class EMAModel:
    """
    Exponential Moving Average of models weights
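Typical use at the top of a training script; a small sketch (the import path diffusers.training_utils is an assumption based on where this diff appears to live):

from diffusers.training_utils import enable_full_determinism, set_seed

set_seed(42)  # seeds random, numpy and torch (CPU and all CUDA devices)

# stricter, but may slow training and reject non-deterministic CUDA kernels:
# enable_full_determinism(42)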
@@ -1,4 +1,8 @@
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,13 +15,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib

import os
from collections import OrderedDict

import importlib_metadata
from requests.exceptions import HTTPError

from .import_utils import (
    ENV_VARS_TRUE_AND_AUTO_VALUES,
    ENV_VARS_TRUE_VALUES,
    USE_JAX,
    USE_TF,
    USE_TORCH,
    DummyObject,
    is_flax_available,
    is_inflect_available,
    is_scipy_available,
    is_tf_available,
    is_torch_available,
    is_transformers_available,
    is_unidecode_available,
    requires_backends,
)
from .logging import get_logger


@@ -35,116 +52,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))


_transformers_available = importlib.util.find_spec("transformers") is not None
try:
    _transformers_version = importlib_metadata.version("transformers")
    logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
    _transformers_available = False


_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    _inflect_version = importlib_metadata.version("inflect")
    logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
    _inflect_available = False


_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
    _unidecode_version = importlib_metadata.version("unidecode")
    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
    _unidecode_available = False


_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
    _modelcards_version = importlib_metadata.version("modelcards")
    logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
except importlib_metadata.PackageNotFoundError:
    _modelcards_available = False


def is_transformers_available():
    return _transformers_available


def is_inflect_available():
    return _inflect_available


def is_unidecode_available():
    return _unidecode_available


def is_modelcards_available():
    return _modelcards_available


class RepositoryNotFoundError(HTTPError):
    """
    Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
    not have access to.
    """


class EntryNotFoundError(HTTPError):
    """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""


class RevisionNotFoundError(HTTPError):
    """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""


TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""


UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""


INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""


BACKENDS_MAPPING = OrderedDict(
    [
        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
        ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
        ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backend` each time a user tries to access any method of that class.
    """

    def __getattr__(cls, key):
        if key.startswith("_"):
            return super().__getattr__(cls, key)
        requires_backends(cls, cls._backends)
src/diffusers/utils/dummy_scipy_objects.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class LMSDiscreteScheduler(metaclass=DummyObject):
    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])

@@ -8,3 +8,10 @@ class LDMTextToImagePipeline(metaclass=DummyObject):

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers"])


class StableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers"])
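What the dummy objects buy you: `import diffusers` never fails just because an optional backend is missing; the ImportError is deferred until the class is actually used. Roughly:

# with scipy absent, schedulers/__init__.py falls back to dummy_scipy_objects,
# so this import succeeds ...
from diffusers import LMSDiscreteScheduler

# ... and only instantiation raises ImportError with the SCIPY_IMPORT_ERROR message
scheduler = LMSDiscreteScheduler()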
255
src/diffusers/utils/import_utils.py
Normal file
255
src/diffusers/utils/import_utils.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Import utilities: Utilities related to imports and our lazy inits.
|
||||
"""
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
from packaging import version
|
||||
|
||||
from . import logging
|
||||
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
if _torch_available:
|
||||
try:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
logger.info(f"PyTorch version {_torch_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
_torch_available = False
|
||||
|
||||
|
||||
_tf_version = "N/A"
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
if _tf_available:
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
"tensorflow-aarch64",
|
||||
)
|
||||
_tf_version = None
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for pkg in candidates:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version(pkg)
|
||||
break
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pass
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if version.parse(_tf_version) < version.parse("2"):
|
||||
logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info(f"TensorFlow version {_tf_version} available.")
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
|
||||
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
if _flax_available:
|
||||
try:
|
||||
_jax_version = importlib_metadata.version("jax")
|
||||
_flax_version = importlib_metadata.version("flax")
|
||||
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_flax_available = False
|
||||
else:
|
||||
_flax_available = False
|
||||
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
|
||||
_inflect_available = importlib.util.find_spec("inflect") is not None
|
||||
try:
|
||||
_inflect_version = importlib_metadata.version("inflect")
|
||||
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_inflect_available = False
|
||||
|
||||
|
||||
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
||||
try:
|
||||
_unidecode_version = importlib_metadata.version("unidecode")
|
||||
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_unidecode_available = False
|
||||
|
||||
|
||||
_modelcards_available = importlib.util.find_spec("modelcards") is not None
|
||||
try:
|
||||
_modelcards_version = importlib_metadata.version("modelcards")
|
||||
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_modelcards_available = False
|
||||
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
_scipy_version = importlib_metadata.version("scipy")
|
||||
logger.debug(f"Successfully imported transformers version {_scipy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scipy_available = False
|
||||
|
||||
|
||||
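Each of these probes follows the same two-step pattern: `importlib.util.find_spec` checks that a package is importable without actually importing it, and `importlib_metadata.version` confirms an installed distribution and records its version. A minimal sketch of that pattern for an arbitrary package; the helper name `_is_package_available` is illustrative only and not part of the file above:

import importlib.util

import importlib_metadata


def _is_package_available(pkg_name: str):
    # Cheap check: is the package importable at all?
    available = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"
    if available:
        try:
            # Confirm an installed distribution and record its version.
            pkg_version = importlib_metadata.version(pkg_name)
        except importlib_metadata.PackageNotFoundError:
            available = False
    return available, pkg_version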
def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


def is_flax_available():
    return _flax_available


def is_transformers_available():
    return _transformers_available


def is_inflect_available():
    return _inflect_available


def is_unidecode_available():
    return _unidecode_available


def is_modelcards_available():
    return _modelcards_available


def is_scipy_available():
    return _scipy_available
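These accessors are what the rest of the library uses to guard optional imports, so optional backends are never imported unconditionally at module top level. A typical guard, sketched for illustration (the exact import path is illustrative, not confirmed by the file above):

from diffusers.utils.import_utils import is_scipy_available  # illustrative import path

if is_scipy_available():
    import scipy.stats  # optional dependency, only imported when present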
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# docstyle-ignore
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""

# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""

# docstyle-ignore
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""

# docstyle-ignore
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""


BACKENDS_MAPPING = OrderedDict(
    [
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
        ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backends` each time a user tries to access any method of that class.
    """

    def __getattr__(cls, key):
        if key.startswith("_"):
            return super().__getattr__(cls, key)
        requires_backends(cls, cls._backends)
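In practice, `DummyObject` serves as the metaclass of placeholder classes that stand in for objects whose backend is missing, so the user only sees a helpful `ImportError` at the point of use instead of a cryptic failure at import time. A minimal sketch; the class name is just an example of a scipy-dependent object, and the snippet assumes scipy is absent:

class LMSDiscreteScheduler(metaclass=DummyObject):
    # Backends this placeholder requires; checked on any public attribute access.
    _backends = ["scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["scipy"])


# With scipy missing, any use raises the SCIPY_IMPORT_ERROR message:
#   LMSDiscreteScheduler.from_config(...)  ->  ImportError: ... `pip install scipy`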
0 tests/models/test_embeddings.py Normal file
@@ -14,16 +14,13 @@
# limitations under the License.


import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample1D, Downsample2D, Upsample1D, Upsample2D
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.models.resnet import Downsample2D, Upsample2D


torch.backends.cuda.matmul.allow_tf32 = False
@@ -219,108 +216,3 @@ class Downsample2DBlockTests(unittest.TestCase):
        output_slice = downsampled[0, -1, -3:, -3:]
        expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class Upsample1DBlockTests(unittest.TestCase):
    def test_upsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64)
        output_slice = upsampled[0, -1, -8:]
        expected_slice = torch.tensor([-1.6340, -1.6340, 0.5374, 0.5374, 1.0826, 1.0826, -1.7105, -1.7105])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64)
        output_slice = upsampled[0, -1, -8:]
        expected_slice = torch.tensor([-0.4546, -0.5010, -0.2996, 0.2844, 0.4040, -0.7772, -0.6862, 0.3612])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=True, out_channels=64)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 64, 64)
        output_slice = upsampled[0, -1, -8:]
        expected_slice = torch.tensor([-0.0516, -0.0972, 0.9740, 1.1883, 0.4539, -0.5285, -0.5851, 0.1152])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_upsample_with_transpose(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 32)
        upsample = Upsample1D(channels=32, use_conv=False, use_conv_transpose=True)
        with torch.no_grad():
            upsampled = upsample(sample)

        assert upsampled.shape == (1, 32, 64)
        output_slice = upsampled[0, -1, -8:]
        expected_slice = torch.tensor([-0.2238, -0.5842, -0.7165, 0.6699, 0.1033, -0.4269, -0.8974, -0.3716])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


class Downsample1DBlockTests(unittest.TestCase):
    def test_downsample_default(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=False)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32)
        output_slice = downsampled[0, -1, -8:]
        expected_slice = torch.tensor([-0.8796, 1.0945, -0.3434, 0.2910, 0.3391, -0.4488, -0.9568, -0.2909])
        # summed absolute difference, not a max (the stricter allclose check is disabled below)
        sum_diff = (output_slice.flatten() - expected_slice).abs().sum().item()
        assert sum_diff <= 1e-3
        # assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)

    def test_downsample_with_conv(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=True)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32)
        output_slice = downsampled[0, -1, -8:]

        expected_slice = torch.tensor(
            [0.1723, 0.0811, -0.6205, -0.3045, 0.0666, -0.2381, -0.0238, 0.2834],
        )
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_pad1(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=True, padding=1)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 32, 32)
        output_slice = downsampled[0, -1, -8:]
        expected_slice = torch.tensor([0.1723, 0.0811, -0.6205, -0.3045, 0.0666, -0.2381, -0.0238, 0.2834])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

    def test_downsample_with_conv_out_dim(self):
        torch.manual_seed(0)
        sample = torch.randn(1, 32, 64)
        downsample = Downsample1D(channels=32, use_conv=True, out_channels=16)
        with torch.no_grad():
            downsampled = downsample(sample)

        assert downsampled.shape == (1, 16, 32)
        output_slice = downsampled[0, -1, -8:]
        expected_slice = torch.tensor([1.1067, -0.5255, -0.4451, 0.0487, -0.3664, -0.7945, -0.4495, -0.3129])
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
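Taken together, these tests pin down the shape contract of the 1D resampling blocks: `Upsample1D` doubles the trailing (length) dimension, `Downsample1D` halves it, and `out_channels` only changes the channel dimension. A quick illustration of the contract, assuming the same constructor signatures the tests above use (shapes only, values untested here):

import torch

from diffusers.models.resnet import Downsample1D, Upsample1D

x = torch.randn(1, 32, 64)  # (batch, channels, length)
assert Upsample1D(channels=32)(x).shape == (1, 32, 128)   # length doubled
assert Downsample1D(channels=32)(x).shape == (1, 32, 32)  # length halved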
@@ -29,12 +29,16 @@ from diffusers import (
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    KarrasVePipeline,
    KarrasVeScheduler,
    LDMPipeline,
    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    StableDiffusionPipeline,
    UNet2DModel,
    VQModel,
)
@@ -555,18 +559,12 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "ch": 64,
            "out_ch": 3,
            "num_res_blocks": 1,
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "attn_resolutions": [],
            "resolution": 32,
            "z_channels": 3,
            "n_embed": 256,
            "embed_dim": 3,
            "sane_index_shape": False,
            "ch_mult": (1,),
            "double_z": False,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 3,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -595,13 +593,13 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
@@ -629,15 +627,12 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "ch": 64,
            "ch_mult": (1,),
            "embed_dim": 4,
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "attn_resolutions": [],
            "num_res_blocks": 1,
            "out_ch": 3,
            "resolution": 32,
            "z_channels": 4,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -666,13 +661,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
        image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        with torch.no_grad():
            output = model(image, sample_posterior=True)

        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
@@ -723,6 +718,27 @@ class PipelineTesterMixin(unittest.TestCase):

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @slow
    def test_from_pretrained_hub_pass_model(self):
        model_path = "google/ddpm-cifar10-32"

        # pass unet into DiffusionPipeline
        unet = UNet2DModel.from_pretrained(model_path)
        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)

        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
        ddpm_from_hub.scheduler.num_timesteps = 10

        generator = torch.manual_seed(0)

        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
        generator = generator.manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]

        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
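The test above exercises a useful loading feature: components passed as keyword arguments to `DiffusionPipeline.from_pretrained` override the corresponding components that would otherwise be loaded from the hub checkpoint. A minimal usage sketch, using the same model id as the test (the injected unet could be any compatible `UNet2DModel`, e.g. a fine-tuned variant):

from diffusers import DiffusionPipeline, UNet2DModel

# Load a unet separately, then inject it into the pipeline in place of the checkpoint's unet.
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", unet=unet)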
    @slow
    def test_output_format(self):
        model_path = "google/ddpm-cifar10-32"
@@ -846,17 +862,62 @@ class PipelineTesterMixin(unittest.TestCase):
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion(self):
        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast("cuda"):
            output = sd_pipe(
                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
            )

        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_fast_ddim(self):
        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        sd_pipe.scheduler = scheduler

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)

        with torch.autocast("cuda"):
            output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output["sample"]

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
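As this test shows, schedulers are hot-swappable on a loaded pipeline: assigning to `pipe.scheduler` is all it takes, provided the replacement is configured with the checkpoint's noise schedule (here the scaled-linear betas used for Stable Diffusion training). A minimal sketch of the same swap outside the test harness:

from diffusers import DDIMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
# Replace the default scheduler with DDIM, keeping the checkpoint's training noise schedule.
pipe.scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
)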
    @slow
    def test_score_sde_ve_pipeline(self):
        model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
        model_id = "google/ncsnpp-church-256"
        model = UNet2DModel.from_pretrained(model_id)

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        scheduler = ScoreSdeVeScheduler.from_config(model_id)

        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")

        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)

        torch.manual_seed(0)
        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
@@ -864,6 +925,7 @@ class PipelineTesterMixin(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)

        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -879,3 +941,80 @@ class PipelineTesterMixin(unittest.TestCase):
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddpm_ddim_equality(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)

        generator = torch.manual_seed(0)
        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
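The equality check works because DDIM's eta parameter interpolates between deterministic sampling (eta = 0) and the ancestral DDPM sampler: with eta = 1.0 and the full 1000 steps, DDIM's per-step noise variance reduces to DDPM's posterior variance. In the DDIM paper's notation, where $\bar{\alpha}_t$ is the cumulative product of the per-step $\alpha$ values, the per-step standard deviation is

\sigma_t(\eta) = \eta \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}} \sqrt{1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}

so eta = 1 recovers the DDPM transition kernel, up to numerical differences, which is why the test uses the loose 1e-1 tolerance.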
    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
    def test_ddpm_ddim_equality_batched(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
        ddim_scheduler = DDIMScheduler(tensor_format="pt")

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)

        generator = torch.manual_seed(0)
        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]

        generator = torch.manual_seed(0)
        ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
            "sample"
        ]

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler(tensor_format="pt")

        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_lms_stable_diffusion_pipeline(self):
        model_id = "CompVis/stable-diffusion-v1-1-diffusers"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
        scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
        pipe.scheduler = scheduler

        prompt = "a photograph of an astronaut riding a horse"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
            "sample"
        ]

        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
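Note the alternative scheduler-loading path used in the LMS test: instead of constructing the scheduler by hand, `from_config` pulls its configuration straight from the checkpoint's `scheduler` subfolder, so the noise schedule is guaranteed to match the model. Sketched in isolation (the model was gated at this point, hence the auth token):

from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-1-diffusers"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
# Load the scheduler config that ships with the checkpoint rather than hand-picking betas.
pipe.scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)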
89 tests/test_training.py Normal file
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.testing_utils import slow, torch_device
from diffusers.training_utils import enable_full_determinism, set_seed


torch.backends.cuda.matmul.allow_tf32 = False


class TrainingTests(unittest.TestCase):
    def get_model_optimizer(self, resolution=32):
        set_seed(0)
        model = UNet2DModel(sample_size=resolution, in_channels=3, out_channels=3)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
        return model, optimizer

    @slow
    def test_training_step_equality(self):
        enable_full_determinism(0)

        ddpm_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
            clip_sample=True,
            tensor_format="pt",
        )
        ddim_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
            clip_sample=True,
            tensor_format="pt",
        )

        assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps

        # shared batches for DDPM and DDIM
        set_seed(0)
        clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(torch_device) for _ in range(4)]
        noise = [torch.randn((4, 3, 32, 32)).to(torch_device) for _ in range(4)]
        timesteps = [torch.randint(0, 1000, (4,)).long().to(torch_device) for _ in range(4)]

        # train with a DDPM scheduler
        model, optimizer = self.get_model_optimizer(resolution=32)
        model.train().to(torch_device)
        for i in range(4):
            optimizer.zero_grad()
            ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
            ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
            loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
            loss.backward()
            optimizer.step()
        del model, optimizer

        # recreate the model and optimizer, and retry with DDIM
        model, optimizer = self.get_model_optimizer(resolution=32)
        model.train().to(torch_device)
        for i in range(4):
            optimizer.zero_grad()
            ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
            ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
            loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
            loss.backward()
            optimizer.step()
        del model, optimizer

        self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
        self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
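The reason the two training runs must agree is that `add_noise` implements the same closed-form forward process for both schedulers. For a clean image $x_0$, noise $\epsilon$, and timestep $t$, the noised sample is

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon

and since both schedulers above are built from the identical linear beta schedule, their $\bar{\alpha}_t$ tables coincide, so the noisy inputs, noise predictions, and gradients match step for step.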
@@ -22,358 +22,138 @@ from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path

from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.utils import ENV_VARS_TRUE_VALUES
from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py
PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_DIFFUSERS = "src/diffusers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source/en"

# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "RealmBertModel",
    "T5Stack",
    "TFDPRSpanPredictor",
]
PRIVATE_MODELS = []
# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
    "OPTDecoder",  # Building part of bigger (tested) model.
    "DecisionTransformerGPT2Model",  # Building part of bigger (tested) model.
    "SegformerDecodeHead",  # Building part of bigger (tested) model.
    "PLBartEncoder",  # Building part of bigger (tested) model.
    "PLBartDecoder",  # Building part of bigger (tested) model.
    "PLBartDecoderWrapper",  # Building part of bigger (tested) model.
    "BigBirdPegasusEncoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DetrEncoder",  # Building part of bigger (tested) model.
    "DetrDecoder",  # Building part of bigger (tested) model.
    "DetrDecoderWrapper",  # Building part of bigger (tested) model.
    "M2M100Encoder",  # Building part of bigger (tested) model.
    "M2M100Decoder",  # Building part of bigger (tested) model.
    "Speech2TextEncoder",  # Building part of bigger (tested) model.
    "Speech2TextDecoder",  # Building part of bigger (tested) model.
    "LEDEncoder",  # Building part of bigger (tested) model.
    "LEDDecoder",  # Building part of bigger (tested) model.
    "BartDecoderWrapper",  # Building part of bigger (tested) model.
    "BartEncoder",  # Building part of bigger (tested) model.
    "BertLMHeadModel",  # Needs to be setup as decoder.
    "BlenderbotSmallEncoder",  # Building part of bigger (tested) model.
    "BlenderbotSmallDecoderWrapper",  # Building part of bigger (tested) model.
    "BlenderbotEncoder",  # Building part of bigger (tested) model.
    "BlenderbotDecoderWrapper",  # Building part of bigger (tested) model.
    "MBartEncoder",  # Building part of bigger (tested) model.
    "MBartDecoderWrapper",  # Building part of bigger (tested) model.
    "MegatronBertLMHeadModel",  # Building part of bigger (tested) model.
    "MegatronBertEncoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoder",  # Building part of bigger (tested) model.
    "MegatronBertDecoderWrapper",  # Building part of bigger (tested) model.
    "PegasusEncoder",  # Building part of bigger (tested) model.
    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DPREncoder",  # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    "RealmBertModel",  # Building part of bigger (tested) model.
    "RealmReader",  # Not regular model.
    "RealmScorer",  # Not regular model.
    "RealmForOpenQA",  # Not regular model.
    "ReformerForMaskedLM",  # Needs to be setup as decoder.
    "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
    "TFDPREncoder",  # Building part of bigger (tested) model.
    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFModelMixin ?)
    "TFRobertaForMultipleChoice",  # TODO: fix
    "TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
    "SeparableConv1D",  # Building part of bigger (tested) model.
    "FlaxBartForCausalLM",  # Building part of bigger (tested) model.
    "FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
    "OPTDecoderWrapper",
]
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + []
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
# trigger the common tests.
TEST_FILES_WITH_NO_COMMON_TESTS = [
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/camembert/test_modeling_camembert.py",
    "models/mt5/test_modeling_flax_mt5.py",
    "models/mbart/test_modeling_mbart.py",
    "models/mt5/test_modeling_mt5.py",
    "models/pegasus/test_modeling_pegasus.py",
    "models/camembert/test_modeling_tf_camembert.py",
    "models/mt5/test_modeling_tf_mt5.py",
    "models/xlm_roberta/test_modeling_tf_xlm_roberta.py",
    "models/xlm_roberta/test_modeling_flax_xlm_roberta.py",
    "models/xlm_prophetnet/test_modeling_xlm_prophetnet.py",
    "models/xlm_roberta/test_modeling_xlm_roberta.py",
    "models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
    "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
    "models/decision_transformer/test_modeling_decision_transformer.py",
]
TEST_FILES_WITH_NO_COMMON_TESTS = []
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "DPTForDepthEstimation",
    "DecisionTransformerGPT2Model",
    "GLPNForDepthEstimation",
    "ViltForQuestionAnswering",
    "ViltForImagesAndTextClassification",
    "ViltForImageAndTextRetrieval",
    "ViltForMaskedLM",
    "XGLMEncoder",
    "XGLMDecoder",
    "XGLMDecoderWrapper",
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "SegformerDecodeHead",
    "FlaxBeitForMaskedImageModeling",
    "PLBartEncoder",
    "PLBartDecoder",
    "PLBartDecoderWrapper",
    "BeitForMaskedImageModeling",
    "CLIPTextModel",
    "CLIPVisionModel",
    "TFCLIPTextModel",
    "TFCLIPVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "FlaxWav2Vec2ForCTC",
    "DetrForSegmentation",
    "DPRReader",
    "FlaubertForQuestionAnswering",
    "FlavaImageCodebook",
    "FlavaTextModel",
    "FlavaImageModel",
    "FlavaMultimodalModel",
    "GPT2DoubleHeadsModel",
    "LukeForMaskedLM",
    "LukeForEntityClassification",
    "LukeForEntityPairClassification",
    "LukeForEntitySpanClassification",
    "OpenAIGPTDoubleHeadsModel",
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "RealmEmbedder",
    "RealmForOpenQA",
    "RealmScorer",
    "RealmReader",
    "TFDPRReader",
    "TFGPT2DoubleHeadsModel",
    "TFOpenAIGPTDoubleHeadsModel",
    "TFRagModel",
    "TFRagSequenceForGeneration",
    "TFRagTokenForGeneration",
    "Wav2Vec2ForCTC",
    "HubertForCTC",
    "SEWForCTC",
    "SEWDForCTC",
    "XLMForQuestionAnswering",
    "XLNetForQuestionAnswering",
    "SeparableConv1D",
    "VisualBertForRegionToPhraseAlignment",
    "VisualBertForVisualReasoning",
    "VisualBertForQuestionAnswering",
    "VisualBertForMultipleChoice",
    "TFWav2Vec2ForCTC",
    "TFHubertForCTC",
    "MaskFormerForInstanceSegmentation",
]

# Update this list for models that have multiple model types for the same
# model doc
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
    [
        ("data2vec-text", "data2vec"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-vision", "data2vec"),
    ]
)
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + []
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_TRANSFORMERS],
    "diffusers",
    os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_DIFFUSERS],
)
transformers = spec.loader.load_module()
diffusers = spec.loader.load_module()
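The spec-loading dance above forces Python to import the in-repo package from `src/diffusers` rather than whatever copy happens to be installed in the environment. Note that `spec.loader.load_module()` has long been deprecated in the standard library; the equivalent modern idiom, sketched here as an aside rather than a change to the script, is:

import importlib.util
import os

spec = importlib.util.spec_from_file_location(
    "diffusers",
    os.path.join("src/diffusers", "__init__.py"),
    submodule_search_locations=["src/diffusers"],
)
diffusers = importlib.util.module_from_spec(spec)
spec.loader.exec_module(diffusers)  # replaces the deprecated spec.loader.load_module()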


def check_model_list():
    """Check the model list inside the transformers library."""
    # Get the models from the directory structure of `src/transformers/models/`
    models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
    _models = []
    for model in os.listdir(models_dir):
        model_dir = os.path.join(models_dir, model)
        if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
            _models.append(model)
def check_modules_are_in_local_init():
    """Check the model list inside the diffusers library."""
    # Get the modules from the directory structure of `src/diffusers/<models,schedulers,pipelines>/`
    modules_dirs = [os.path.join(PATH_TO_DIFFUSERS, subdir) for subdir in ["models", "pipelines", "schedulers"]]
    for modules_dir in modules_dirs:
        _modules = []
        for module in os.listdir(modules_dir):
            module_dir = os.path.join(modules_dir, module)
            if os.path.isdir(module_dir) and "__init__.py" in os.listdir(module_dir):
                _modules.append(module)
            elif os.path.isfile(module_dir) and not module.startswith("_") and module_dir.endswith(".py"):
                _modules.append(module.replace(".py", ""))

    # Get the models from the directory structure of `src/transformers/models/`
    models = [model for model in dir(transformers.models) if not model.startswith("__")]
    # Get the modules from the directory structure of `src/diffusers/<models,schedulers,pipelines>/`
    module_dirs = dir(diffusers.models) + dir(diffusers.pipelines) + dir(diffusers.schedulers)
    modules = [module for module in module_dirs if not module.startswith("__")]

    missing_models = sorted(list(set(_models).difference(models)))
    if missing_models:
        raise Exception(
            f"The following models should be included in {models_dir}/__init__.py: {','.join(missing_models)}."
        )
    missing_modules = sorted(list(set(_modules).difference(modules)))
    if missing_modules:
        raise Exception(
            f"The following modules should be included in {modules_dir}/__init__.py: {','.join(missing_modules)}."
        )
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
    """Get the model modules inside the transformers library."""
    _ignore_modules = [
        "modeling_auto",
        "modeling_encoder_decoder",
        "modeling_marian",
        "modeling_mmbt",
        "modeling_outputs",
        "modeling_retribert",
        "modeling_utils",
        "modeling_flax_auto",
        "modeling_flax_encoder_decoder",
        "modeling_flax_utils",
        "modeling_speech_encoder_decoder",
        "modeling_flax_speech_encoder_decoder",
        "modeling_flax_vision_encoder_decoder",
        "modeling_transfo_xl_utilities",
        "modeling_tf_auto",
        "modeling_tf_encoder_decoder",
        "modeling_tf_outputs",
        "modeling_tf_pytorch_utils",
        "modeling_tf_utils",
        "modeling_tf_transfo_xl_utilities",
        "modeling_tf_vision_encoder_decoder",
        "modeling_vision_encoder_decoder",
    ]
    """Get the model modules inside the diffusers library."""
    _ignore_modules = []
    modules = []
    for model in dir(transformers.models):
    for model in dir(diffusers.models):
        # There are some magic dunder attributes in the dir, we ignore them
        if not model.startswith("__"):
            model_module = getattr(transformers.models, model)
            for submodule in dir(model_module):
                if submodule.startswith("modeling") and submodule not in _ignore_modules:
                    modeling_module = getattr(model_module, submodule)
                    if inspect.ismodule(modeling_module):
                        modules.append(modeling_module)
            model_module = getattr(diffusers.models, model)
            if inspect.ismodule(model_module):
                modules.append(model_module)
    return modules
def get_models(module, include_pretrained=False):
    """Get the objects in module that are models."""
    models = []
    model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
def get_modules(module):
    """Get the objects in module that are models/schedulers/pipelines."""
    objects = []
    objects_classes = (diffusers.modeling_utils.ModelMixin, diffusers.SchedulerMixin, diffusers.DiffusionPipeline)
    for attr_name in dir(module):
        if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
            models.append((attr_name, attr))
    return models
        if isinstance(attr, type) and issubclass(attr, objects_classes) and attr.__module__ == module.__name__:
            objects.append((attr_name, attr))
    return objects
def is_a_private_model(model):
    """Returns True if the model should not be in the main init."""
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder and Decoder are all privates
    if model.endswith("Wrapper"):
        return True
    if model.endswith("Encoder"):
        return True
    if model.endswith("Decoder"):
        return True
    return False
def check_models_are_in_init():
def check_modules_are_in_global_init():
    """Checks all models defined in the library are in the main init."""
    models_not_in_init = []
    dir_transformers = dir(transformers)
    modules_not_in_init = []
    dir_diffusers = dir(diffusers)
    for module in get_model_modules():
        models_not_in_init += [
            model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
        modules_not_in_init += [
            module[0] for module in get_modules(module) if module[0] not in dir_diffusers
        ]

    # Remove private models
    models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
    if len(models_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
    if len(modules_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(modules_not_in_init)}.")
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# If some test files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
    """Get the model test files.
def get_module_test_files():
    """Get the model/scheduler/pipeline test files.

    The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
    considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
    """

    _ignore_files = [
        "test_modeling_common",
        "test_modeling_encoder_decoder",
        "test_modeling_flax_encoder_decoder",
        "test_modeling_flax_speech_encoder_decoder",
        "test_modeling_marian",
        "test_modeling_tf_common",
        "test_modeling_tf_encoder_decoder",
    ]
    _ignore_files = []
    test_files = []
    # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
    model_test_root = os.path.join(PATH_TO_TESTS, "models")

    model_test_root = os.path.join(PATH_TO_TESTS)
    model_test_dirs = []
    for x in os.listdir(model_test_root):
        x = os.path.join(model_test_root, x)
        if os.path.isdir(x):
            model_test_dirs.append(x)

    for target_dir in [PATH_TO_TESTS] + model_test_dirs:
    for target_dir in model_test_dirs:
        for file_or_dir in os.listdir(target_dir):
            path = os.path.join(target_dir, file_or_dir)
            if os.path.isfile(path):
                filename = os.path.split(path)[-1]
                if "test_modeling" in filename and not os.path.splitext(filename)[0] in _ignore_files:
                if "test_" in filename and not os.path.splitext(filename)[0] in _ignore_files:
                    file = os.path.join(*path.split(os.sep)[1:])
                    test_files.append(file)

    return test_files
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
    """Parse the content of test_file to detect what's in all_model_classes"""
    # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
    with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
        content = f.read()
    all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
    # Check with one less parenthesis as well
    all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
    if len(all_models) > 0:
        model_tested = []
        for entry in all_models:
            for line in entry.split(","):
                name = line.strip()
                if len(name) > 0:
                    model_tested.append(name)
        return model_tested
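Concretely, the two regexes accept both the nested and the flat tuple spellings of the class list in a tester. A hypothetical tester snippet that either pattern would parse:

# Hypothetical tester file content that find_tested_models can parse:
#
#   all_model_classes = (UNet2DModel,)                                     # flat tuple: second regex
#   all_model_classes = ((UNet2DModel,) if is_torch_available() else ())  # nested tuple: first regex
#
# Either spelling yields ["UNet2DModel"] after the split/strip loop above.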

def check_models_are_tested(module, test_file):
    """Check models defined in module are tested in test_file."""
    # XxxModelMixin are not tested
    defined_models = get_models(module)
    defined_models = get_modules(module)
    tested_models = find_tested_models(test_file)
    if tested_models is None:
        if test_file.replace(os.path.sep, "/") in TEST_FILES_WITH_NO_COMMON_TESTS:
@@ -395,10 +175,10 @@ def check_models_are_tested(module, test_file):
    return failures


def check_all_models_are_tested():
    """Check all models are properly tested."""
def check_all_modules_are_tested():
    """Check all models/schedulers/pipelines are properly tested."""
    modules = get_model_modules()
    test_files = get_model_test_files()
    test_files = get_module_test_files()
    failures = []
    for module in modules:
        test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
@@ -414,84 +194,6 @@ def check_all_models_are_tested():
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
    """Return the list of all models in at least one auto class."""
    result = set()  # To avoid duplicates we concatenate all model classes in a set.
    if is_torch_available():
        for attr_name in dir(transformers.models.auto.modeling_auto):
            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
    if is_tf_available():
        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
    if is_flax_available():
        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
    return [cls for cls in result]


def ignore_unautoclassed(model_name):
    """Rules to determine if `name` should be in an auto class."""
    # Special white list
    if model_name in IGNORE_NON_AUTO_CONFIGURED:
        return True
    # Encoder and Decoder should be ignored
    if "Encoder" in model_name or "Decoder" in model_name:
        return True
    return False


def check_models_are_auto_configured(module, all_auto_models):
    """Check models defined in module are each in an auto class."""
    defined_models = get_models(module)
    failures = []
    for model_name, _ in defined_models:
        if model_name not in all_auto_models and not ignore_unautoclassed(model_name):
            failures.append(
                f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
                "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
                "`utils/check_repo.py`."
            )
    return failures


def check_all_models_are_auto_configured():
    """Check all models are each in an auto class."""
    missing_backends = []
    if not is_torch_available():
        missing_backends.append("PyTorch")
    if not is_tf_available():
        missing_backends.append("TensorFlow")
    if not is_flax_available():
        missing_backends.append("Flax")
    if len(missing_backends) > 0:
        missing = ", ".join(missing_backends)
        if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
            raise Exception(
                "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
                f"Transformers repo, the following are missing: {missing}."
            )
        else:
            warnings.warn(
                "Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
                f"Transformers repo, the following are missing: {missing}. While it's probably fine as long as you "
                "didn't make any change in one of those backends modeling files, you should probably execute the "
                "command above to be on the safe side."
            )
    modules = get_model_modules()
    all_auto_models = get_all_auto_configured_models()
    failures = []
    for module in modules:
        new_failures = check_models_are_auto_configured(module, all_auto_models)
        if new_failures is not None:
            failures += new_failures
    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
@@ -546,77 +248,14 @@ def find_all_documented_objects():


# One good reason for not being documented is to be deprecated. Put in this list deprecated objects.
DEPRECATED_OBJECTS = [
    "AutoModelWithLMHead",
    "BartPretrainedModel",
    "DataCollator",
    "DataCollatorForSOP",
    "GlueDataset",
    "GlueDataTrainingArguments",
    "LineByLineTextDataset",
    "LineByLineWithRefDataset",
    "LineByLineWithSOPTextDataset",
    "PretrainedBartModel",
    "PretrainedFSMTModel",
    "SingleSentenceClassificationProcessor",
    "SquadDataTrainingArguments",
    "SquadDataset",
    "SquadExample",
    "SquadFeatures",
    "SquadV1Processor",
    "SquadV2Processor",
    "TFAutoModelWithLMHead",
    "TFBartPretrainedModel",
    "TextDataset",
    "TextDatasetForNextSentencePrediction",
    "Wav2Vec2ForMaskedLM",
    "Wav2Vec2Tokenizer",
    "glue_compute_metrics",
    "glue_convert_examples_to_features",
    "glue_output_modes",
    "glue_processors",
    "glue_tasks_num_labels",
    "squad_convert_examples_to_features",
    "xnli_compute_metrics",
    "xnli_output_modes",
    "xnli_processors",
    "xnli_tasks_num_labels",
    "TFTrainer",
    "TFTrainingArguments",
]
DEPRECATED_OBJECTS = []
# Exceptionally, some objects should not be documented after all rules passed.
# ONLY PUT SOMETHING IN THIS LIST AS A LAST RESORT!
UNDOCUMENTED_OBJECTS = [
    "AddedToken",  # This is a tokenizers class.
    "BasicTokenizer",  # Internal, should never have been in the main init.
    "CharacterTokenizer",  # Internal, should never have been in the main init.
    "DPRPretrainedReader",  # Like an Encoder.
    "DummyObject",  # Just picked by mistake sometimes.
    "MecabTokenizer",  # Internal, should never have been in the main init.
    "ModelCard",  # Internal type.
    "SqueezeBertModule",  # Internal building block (should have been called SqueezeBertLayer)
    "TFDPRPretrainedReader",  # Like an Encoder.
    "TransfoXLCorpus",  # Internal type.
    "WordpieceTokenizer",  # Internal, should never have been in the main init.
    "absl",  # External module
    "add_end_docstrings",  # Internal, should never have been in the main init.
    "add_start_docstrings",  # Internal, should never have been in the main init.
    "cached_path",  # Internal used for downloading models.
    "convert_tf_weight_name_to_pt_weight_name",  # Internal used to convert model weights
    "logger",  # Internal logger
    "logging",  # External module
    "requires_backends",  # Internal function
]
UNDOCUMENTED_OBJECTS = []

# This list should be empty. Objects in it should get their own doc page.
SHOULD_HAVE_THEIR_OWN_PAGE = [
    # Benchmarks
    "PyTorchBenchmark",
    "PyTorchBenchmarkArguments",
    "TensorFlowBenchmark",
    "TensorFlowBenchmarkArguments",
]
SHOULD_HAVE_THEIR_OWN_PAGE = []
def ignore_undocumented(name):
@@ -636,8 +275,8 @@ def ignore_undocumented(name):
    ):
        return True
    # Submodules are not documented.
    if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
        os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
    if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
        os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
    ):
        return True
    # All load functions are not documented.
@@ -660,9 +299,7 @@ def ignore_undocumented(name):
def check_all_objects_are_documented():
    """Check all models are properly documented."""
    documented_objs = find_all_documented_objects()
    modules = transformers._modules
    objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
    undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
    undocumented_objs = [c for c in dir(diffusers) if c not in documented_objs and not ignore_undocumented(c) and not c.startswith("_")]
    if len(undocumented_objs) > 0:
        raise Exception(
            "The following objects are in the public init so should be documented:\n - "
@@ -677,8 +314,7 @@ def check_model_type_doc_match():
    model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
    model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]

    model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
    model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
    model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())

    errors = []
    for m in model_docs:
@@ -723,7 +359,7 @@ def is_rst_docstring(docstring):
def check_docstrings_are_in_md():
    """Check all docstrings are in md"""
    files_with_rst = []
    for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
    for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
        with open(file, "r") as f:
            code = f.read()
            docstrings = code.split('"""')
@@ -745,17 +381,15 @@ def check_docstrings_are_in_md():

def check_repo_quality():
    """Check all models are properly tested and documented."""
    print("Checking all models are included.")
    check_model_list()
    print("Checking all models are public.")
    check_models_are_in_init()
    print("Checking all models are properly tested.")
    print("Checking all models, schedulers and pipelines are included.")
    check_modules_are_in_local_init()
    print("Checking all models, schedulers and pipelines are public.")
    check_modules_are_in_global_init()
    print("Checking all models, schedulers and pipelines are properly tested.")
    check_all_decorator_order()
    check_all_models_are_tested()
    check_all_modules_are_tested()
    print("Checking all objects are properly documented.")
    check_all_objects_are_documented()
    print("Checking all models are in at least one auto class.")
    check_all_models_are_auto_configured()


if __name__ == "__main__":
    check_repo_quality()