Compare commits

..

43 Commits

Author SHA1 Message Date
Anton Lozhkov
0679d09083 Release: 5.0.0 (#830) 2022-10-13 18:48:50 +02:00
Patrick von Platen
1d51224403 [Flax] Complete tests (#828) 2022-10-13 18:18:32 +02:00
Patrick von Platen
7c2262640b Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827)
* up

* finish

* add more tests

* up

* up

* finish
2022-10-13 17:43:06 +02:00
Pedro Cuenca
78db11dbf3 Flax safety checker (#825)
* Remove set_format in Flax pipeline.

* Remove DummyChecker.

* Run safety_checker in pipeline.

* Don't pmap on every call.

We could have decorated `generate` with `pmap`, but I wanted to keep it
in case someone wants to invoke it in non-parallel mode.

* Remove commented line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Replicate outside __call__, prepare for optional jitting.

* Remove unnecessary clipping.

As suggested by @kashif.

* Do not jit unless requested.

* Send all args to generate.

* make style

* Remove unused imports.

* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-10-13 17:01:47 +02:00
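The pmap/jit notes in this commit describe a compile-on-request pattern. A minimal editorial sketch (hypothetical names, not the actual pipeline code):

```python
import jax
import jax.numpy as jnp

def generate(params, latents):
    # stand-in for the real denoising loop
    return latents * params["scale"]

class FlaxPipelineSketch:
    def __init__(self, jit: bool = False):
        # Compile once up front only if the caller asks for it; keeping the
        # eager path available allows non-parallel invocation, as the commit notes.
        self._generate = jax.jit(generate) if jit else generate

    def __call__(self, params, latents):
        return self._generate(params, latents)

pipe = FlaxPipelineSketch(jit=True)
print(pipe({"scale": jnp.float32(2.0)}, jnp.ones((4,))))
```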
Patrick von Platen
e713346ad1 Give more customizable options for safety checker (#815)
* Give more customizable options for safety checker

* Apply suggestions from code review

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* Finish

* make style

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-10-13 15:52:26 +02:00
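For context, `from_pretrained` accepts component overrides, so a custom (or disabled) checker can be passed at load time. A hedged editorial sketch; the exact behavior of `safety_checker=None` depends on the release:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,  # assumption: disabling warns rather than errors
)
```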
Anton Lozhkov
26c7df5d82 Fix type mismatch error, add tests for negative prompts (#823) 2022-10-13 15:45:42 +02:00
Anton Lozhkov
e001fededf Fix dreambooth loss type with prior_preservation and fp16 (#826)
Fix dreambooth loss type with prior preservation
2022-10-13 15:41:19 +02:00
Suraj Patil
0a09af2f0a update flax scheduler API (#822)
* update flax scheduler API

* remove set_format

* fix call to scale_model_input

* update flax pndm

* use int32

* update docstr
2022-10-13 15:40:01 +02:00
Patrick von Platen
f1d4289be8 [Flax] Add test (#824) 2022-10-13 13:55:39 +02:00
Anton Lozhkov
323a9e1f6d Add diffusers version and pipeline class to the Hub UA (#814)
* Add diffusers version and pipeline class to the Hub UA

* Fallback to class name for pipelines

* Update src/diffusers/modeling_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove autoclass

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-10-12 21:54:40 +02:00
pink-red
60c384bcd2 Fix fine-tuning compatibility with deepspeed (#816) 2022-10-12 21:43:37 +02:00
Suraj Patil
008b608f15 [train_text2image] Fix EMA and make it compatible with deepspeed. (#813)
* fix ema

* style

* add comment about copy

* style

* quality
2022-10-12 19:13:22 +02:00
Nathan Lambert
5afc2b60cd add or fix license formatting in models directory (#808)
* add or fix license formatting

* fix quality
2022-10-12 08:19:35 -07:00
anton-l
96598639c0 Revert an accidental commit
This reverts commit 679c77f8ea.
2022-10-12 17:20:44 +02:00
anton-l
80be0744a6 Merge remote-tracking branch 'origin/main' 2022-10-12 17:18:42 +02:00
anton-l
679c77f8ea Add diffusers version and pipeline class to the Hub UA 2022-10-12 17:18:32 +02:00
Patrick von Platen
db47b1e4d9 [Dummy imports] Better error message (#795)
* [Dummy imports] Better error message

* Test: load pipeline with LMS scheduler.

Fails with a cryptic message if scipy is not installed.

* Correct

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-10-12 14:41:16 +02:00
Anton Lozhkov
966e2fc461 Minor package fixes (#809) 2022-10-12 13:22:51 +02:00
Patrick von Platen
6bc11782b7 [Img2Img] Fix batch size mismatch prompts vs. init images (#793)
* [Img2Img] Fix batch size mismatch prompts vs. init images

* Remove bogus folder

* fix

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-10-12 13:00:36 +02:00
Patrick von Platen
c1b6ea3dce Update img2img.mdx 2022-10-12 00:52:30 +02:00
Pedro Cuenca
24b8b5cf5e mps: Alternative implementation for repeat_interleave (#766)
* mps: alt. implementation for repeat_interleave

* style

* Bump mps version of PyTorch in the documentation.

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Simplify: do not check for device.

* style

* Fix repeat dimensions:

- The unconditional embeddings are always created from a single prompt.
- I was shadowing the batch_size var.

* Split long lines as suggested by Suraj.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
2022-10-11 20:30:09 +02:00
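An editorial sketch of the kind of `repeat_interleave` alternative this commit describes, built from `expand` + `reshape` (function name hypothetical):

```python
import torch

def repeat_interleave_batch(x: torch.Tensor, repeats: int) -> torch.Tensor:
    # Duplicate each batch element `repeats` times without repeat_interleave,
    # which was problematic on mps at the time of this PR.
    b = x.shape[0]
    return x.unsqueeze(1).expand(b, repeats, *x.shape[1:]).reshape(b * repeats, *x.shape[1:])

x = torch.arange(6).reshape(3, 2)
assert torch.equal(repeat_interleave_batch(x, 2), x.repeat_interleave(2, dim=0))
```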
Omar Sanseviero
757babfcad Fix indentation in the code example (#802)
Update custom_pipelines.mdx
2022-10-11 20:26:52 +02:00
spezialspezial
e895952816 Eventually preserve this typo? :) (#804) 2022-10-11 20:06:24 +02:00
Akash Pannu
a124204490 Flax: Trickle down norm_num_groups (#789)
* pass norm_num_groups param and add tests

* set resnet_groups for FlaxUNetMidBlock2D

* fixed docstrings

* fixed typo

* using is_flax_available util and created require_flax decorator
2022-10-11 20:05:10 +02:00
Suraj Patil
66a5279a94 stable diffusion fine-tuning (#356)
* begin text2image script

* loading the datasets, preprocessing & transforms

* handle input features correctly

* add gradient checkpointing support

* fix output names

* run unet in train mode not text encoder

* use no_grad instead of freezing params

* default max steps None

* pad to longest

* don't pad when tokenizing

* fix encode on multi gpu

* fix stupid bug

* add random flip

* add ema

* fix ema

* put ema on cpu

* improve EMA model

* contiguous_format

* don't wrap vae and text encoder in accelerate

* remove no_grad

* use randn_like

* fix resize

* improve few things

* log epoch loss

* set log level

* don't log each step

* remove max_length from collate

* style

* add report_to option

* make scale_lr false by default

* add grad clipping

* add an option to use 8bit adam

* fix logging in multi-gpu, log every step

* more comments

* remove eval for now

* address review comments

* add requirements file

* begin readme

* begin readme

* fix typo

* fix push to hub

* populate readme

* update readme

* remove use_auth_token from the script

* address some review comments

* better mixed precision support

* remove redundant to

* create ema model early

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* better description for train_data_dir

* add diffusers in requirements

* update dataset_name_mapping

* update readme

* add inference example

Co-authored-by: anton-l <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-10-11 19:03:39 +02:00
Suraj Patil
797b290ed0 support bf16 for stable diffusion (#792)
* support bf16 for stable diffusion

* fix typo

* address review comments
2022-10-11 12:02:12 +02:00
Henrik Forstén
81bdbb5e2a DreamBooth DeepSpeed support for under 8 GB VRAM training (#735)
* Support deepspeed

* Dreambooth DeepSpeed documentation

* Remove unnecessary casts, documentation

Due to recent commits, some casts to half precision are no longer
necessary.

Mention that DeepSpeed's version of Adam is about 2x faster.

* Review comments
2022-10-10 21:29:27 +02:00
Nathan Lambert
71ca10c6a4 fix typo docstring in unet2d (#798)
fix typo docstring
2022-10-10 11:25:20 -07:00
Patrick von Platen
22963ed826 Fix gradient checkpointing test (#797)
* Fix gradient checkpointing test

* more tests
2022-10-10 19:40:33 +02:00
Patrick von Platen
fab17528da [Low CPU memory] + device map (#772)
* add accelerate to load models with smaller memory footprint

* remove low_cpu_mem_usage as it is redundant

* move accelerate init weights context to modelling utils

* add test to ensure results are the same when loading with accelerate

* add tests to ensure ram usage gets lower when using accelerate

* move accelerate logic to single snippet under modelling utils and remove it from configuration utils

* format code using to pass quality check

* fix imports with isort

* add accelerate to test extra deps

* only import accelerate if device_map is set to auto

* move accelerate availability check to diffusers import utils

* format code

* add device map to pipeline abstraction

* lint it to pass PR quality check

* fix class check to use accelerate when using diffusers ModelMixin subclasses

* use low_cpu_mem_usage in transformers if device_map is not available

* NoModuleLayer

* comment out tests

* up

* uP

* finish

* Update src/diffusers/pipelines/stable_diffusion/safety_checker.py

* finish

* uP

* make style

Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
2022-10-10 18:05:49 +02:00
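An editorial sketch of the loading path this PR adds; model id illustrative, and `accelerate` must be installed:

```python
from diffusers import StableDiffusionPipeline

# device_map="auto" lets accelerate materialize weights with a smaller
# CPU-RAM footprint instead of fully initializing them first.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    device_map="auto",
)
```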
Nathan Lambert
feaa73243d add sigmoid betas (#777)
* add sigmoid betas

* convert to torch

* add comment on source
2022-10-10 08:28:10 -07:00
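An editorial sketch of a sigmoid beta schedule of this kind; the `(-6, 6)` range and the endpoints are illustrative assumptions, not necessarily the library's values:

```python
import torch

def sigmoid_betas(num_steps: int, beta_start: float = 1e-4, beta_end: float = 2e-2) -> torch.Tensor:
    # Squash a linear ramp through a sigmoid, then rescale to [beta_start, beta_end].
    x = torch.linspace(-6, 6, num_steps)
    return torch.sigmoid(x) * (beta_end - beta_start) + beta_start
```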
Nathan Lambert
a73f8b7251 Clean up resnet.py file (#780)
* clean up resnet.py

* make style and quality

* minor formatting
2022-10-10 08:27:50 -07:00
lowinli
5af6eed9ee debug an exception (#638)
* debug an exception

If dst_path is not a file, src_path.samefile raises an exception:
FileNotFoundError: [Errno 2] No such file or directory: '/home/lilongwei/notebook/onnx_diffusion/vae_decoder/model.onnx'

* Update src/diffusers/onnx_utils.py

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
2022-10-10 13:02:18 +02:00
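An editorial sketch of the failure mode and guard described here (helper name hypothetical):

```python
from pathlib import Path

def needs_copy(src_path: Path, dst_path: Path) -> bool:
    # Path.samefile stats both paths, so it raises FileNotFoundError when the
    # destination doesn't exist yet; check existence before comparing.
    return not dst_path.is_file() or not src_path.samefile(dst_path)
```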
Patrick von Platen
f3983d16ee [Tests] Fix tests (#774)
* Fix tests

* remove bogus file
2022-10-07 19:38:40 +02:00
Suraj Patil
92d7086366 [img2img, inpainting] fix fp16 inference (#769)
* handle dtype in vae and image2image pipeline

* fix inpaint in fp16

* dtype should be handled in add_noise

* style

* address review comments

* add simple fast tests to check fp16

* fix test name

* put mask in fp16
2022-10-07 17:01:51 +02:00
Suraj Patil
ec831b6a72 [schedulers] handle dtype in add_noise (#767)
* handle dtype in vae and image2image pipeline

* handle dtype in add noise

* don't modify vae and pipeline

* remove the if
2022-10-07 16:44:19 +02:00
Kevin Turner
cb0bf0bd0b fix(DDIM scheduler): use correct dtype for noise (#742)
Otherwise, it crashes when eta > 0 with float16.
2022-10-07 16:02:32 +02:00
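An editorial sketch of the class of fix in the fp16/dtype commits above: create noise with the dtype and device of the tensor it will be combined with, so float16 sampling with eta > 0 doesn't hit a dtype mismatch (helper name hypothetical):

```python
import torch

def noise_like(sample: torch.Tensor, generator=None) -> torch.Tensor:
    # Match device and dtype of the target tensor instead of defaulting to float32.
    return torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
```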
James R T
e0fece2b26 Add final latent slice checks to SD pipeline intermediate state tests (#731)
This is to ensure that the final latent slices stay somewhat consistent as more changes are introduced into the library.

Signed-off-by: James R T <jamestiotio@gmail.com>

Signed-off-by: James R T <jamestiotio@gmail.com>
2022-10-07 15:50:20 +02:00
Justin Chu
75bb6d2d46 Fix ONNX conversion script opset argument type (#739)
The opset argument should be an `int` but was set as a `str`.
2022-10-07 15:47:43 +02:00
YaYaB
906e4105d7 Fix push_to_hub for dreambooth and textual_inversion (#748)
* Fix push_to_hub for dreambooth and textual_inversion

* Use repo.push_to_hub instead of push_to_hub
2022-10-07 11:50:28 +02:00
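An editorial sketch of the corrected call shape from this fix: push via the `Repository` object itself rather than a standalone helper (local path and repo id illustrative):

```python
from huggingface_hub import Repository

repo = Repository("path-to-save-model", clone_from="user/my-model")
# ... training saves into the local clone ...
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
```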
Patrick von Platen
7258dc4943 remove bogus folder no.2 2022-10-07 11:21:24 +02:00
Patrick von Platen
c93a8cc901 remove bogus folder 2022-10-07 11:20:26 +02:00
Patrick von Platen
9a95414ea1 Bump to v0.5.0dev0 2022-10-07 11:17:55 +02:00
64 changed files with 2360 additions and 2068 deletions

View File

@@ -1 +1,2 @@
include LICENSE
include src/diffusers/utils/model_card_template.md

View File

@@ -1,6 +1,6 @@
<p align="center">
<br>
<img src="docs/source/imgs/diffusers_library.jpg" width="400"/>
<img src="https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg" width="400"/>
<br>
<p>
<p align="center">

File diff suppressed because one or more lines are too long

View File

@@ -91,24 +91,24 @@ class MyPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))
image = image.to(self.device)

# set step values
self.scheduler.set_timesteps(num_inference_steps)

for t in self.progress_bar(self.scheduler.timesteps):
    # 1. predict noise model_output
    model_output = self.unet(image, t).sample

    # 2. predict previous mean of image x_t-1 and add variance depending on eta
    # eta corresponds to η in paper and should be between [0, 1]
    # do x_t -> x_t-1
    image = self.scheduler.step(model_output, t, image, eta).prev_sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

return image
```
Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
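As an editorial sketch (repo ids assumed for illustration), such a community pipeline can then be loaded by passing its repo id as `custom_pipeline`:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32",
    custom_pipeline="<your-username>/<repo-containing-pipeline-py>",
)
```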

View File

@@ -15,6 +15,7 @@ specific language governing permissions and limitations under the License.
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
```python
import torch
import requests
from PIL import Image
from io import BytesIO

View File

@@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Training on an 8 GB GPU
By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.
DeepSpeed needs to be enabled with `accelerate config`. During configuration,
answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
mixed precision, and offloading both parameters and the optimizer state to CPU,
it's possible to train on under 8 GB of VRAM, at the cost of requiring
significantly more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
Changing the default Adam optimizer to DeepSpeed's optimized version,
`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
it requires a CUDA toolchain with the same version as the one PyTorch was built
with (a sketch of this swap follows the example below). The 8-bit optimizer does
not seem to be compatible with DeepSpeed at the moment.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 --gradient_checkpointing \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800 \
--mixed_precision=fp16
```
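A hedged sketch of the optimizer swap mentioned above; the fallback branch is an assumption for machines without DeepSpeed:

```python
try:
    from deepspeed.ops.adam import DeepSpeedCPUAdam
    optimizer_cls = DeepSpeedCPUAdam  # ~2x faster per the PR, needs a matching CUDA toolchain
except ImportError:
    import torch
    optimizer_cls = torch.optim.AdamW
```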
## Inference

View File

@@ -77,7 +77,7 @@ def parse_args():
type=int,
default=100,
help=(
"Minimal class images for prior perversation loss. If not have enough images, additional images will be"
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
@@ -471,9 +471,17 @@ def main():
unet, optimizer, train_dataloader, lr_scheduler
)
# Move text_encode and vae to gpu
text_encoder.to(accelerator.device)
vae.to(accelerator.device)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -509,11 +517,11 @@ def main():
with accelerator.accumulate(unet):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -536,15 +544,15 @@ def main():
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -575,9 +583,7 @@ def main():
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()

View File

@@ -0,0 +1,101 @@
# Stable Diffusion text-to-image fine-tuning
The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.
___Note___:
___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
## Running locally
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
### Pokemon example
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
Run the following command to authenticate your token:
```bash
huggingface-cli login
```
If you have already cloned the repo, then you won't need to go through these steps.
<br>
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24 GB GPU. For higher `batch_size` and faster training it's better to use GPUs with more than 30 GB of memory.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```
To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```
Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
```python
import torch
from diffusers import StableDiffusionPipeline
model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```

View File

@@ -0,0 +1,7 @@
diffusers==0.4.1
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards

View File

@@ -0,0 +1,627 @@
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-model-finetuned",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
return args
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
dataset_name_mapping = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
parameters = list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
self.decay = decay
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
value = (1 + optimization_step) / (10 + optimization_step)
return 1 - min(self.decay, value)
@torch.no_grad()
def step(self, parameters):
parameters = list(parameters)
self.optimization_step += 1
self.decay = self.get_decay(self.optimization_step)
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
tmp = self.decay * (s_param - param)
s_param.sub_(tmp)
else:
s_param.copy_(param)
torch.cuda.empty_cache()
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = list(parameters)
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
device: like `device` argument to `torch.Tensor.to`
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
for p in self.shadow_params
]
def main():
args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
input_ids = inputs.input_ids
return input_ids
train_transforms = transforms.Compose(
[
transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
examples["input_ids"] = tokenize_captions(examples)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = [example["input_ids"] for example in examples]
padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
return {
"pixel_values": pixel_values,
"input_ids": padded_tokens.input_ids,
"attention_mask": padded_tokens.attention_mask,
}
train_dataloader = torch.utils.data.DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
# Create EMA for the unet.
if args.use_ema:
ema_unet = EMAModel(unet.parameters())
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
if __name__ == "__main__":
main()
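A quick numeric check (editorial sketch, not part of the diff) of the warm-up behavior of `EMAModel.get_decay` above: it returns the update weight `1 - min(0.9999, (1 + t) / (10 + t))`, so early steps track the raw parameters closely while later steps average slowly.

```python
for t in [1, 10, 100, 10_000, 1_000_000]:
    update_weight = 1 - min(0.9999, (1 + t) / (10 + t))
    print(t, round(update_weight, 6))
# 1 -> 0.818182, 10 -> 0.45, 100 -> 0.081818, 10_000 -> 0.000899, 1_000_000 -> 0.0001
```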

View File

@@ -569,9 +569,7 @@ def main():
save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
)
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()

View File

@@ -9,7 +9,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
@@ -185,7 +185,7 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()

View File

@@ -1,102 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class CustomPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
batch_size: int = 1,
generator: Optional[torch.Generator] = None,
eta: float = 0.0,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
eta (`float`, *optional*, defaults to 0.0):
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
image = image.to(self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t).sample
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image), "This is a test"

View File

@@ -206,7 +206,7 @@ if __name__ == "__main__":
parser.add_argument(
"--opset",
default=14,
type=str,
type=int,
help="The version of the ONNX operator set to use.",
)

View File

@@ -211,7 +211,7 @@ install_requires = [
setup(
name="diffusers",
version="0.4.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.5.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -9,7 +9,7 @@ from .utils import (
)
__version__ = "0.4.2"
__version__ = "0.5.0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel

View File

@@ -27,7 +27,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import is_torch_available
from . import __version__, is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .utils import (
CONFIG_NAME,
@@ -286,10 +286,13 @@ class FlaxModelMixin:
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "flax",
}
# Load config if we don't provide a configuration
config_path = config if config is not None else pretrained_model_name_or_path

View File

@@ -26,6 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import __version__
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
@@ -292,12 +293,15 @@ class ModelMixin(torch.nn.Module):
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_auto_class = kwargs.pop("_from_auto", False)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

View File

@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -40,6 +41,13 @@ class Upsample2D(nn.Module):
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
@@ -47,6 +55,10 @@ class Upsample2D(nn.Module):
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
@@ -61,9 +73,10 @@ class Downsample2D(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -115,21 +128,22 @@ class FirUpsample2D(nn.Module):
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -164,7 +178,6 @@ class FirUpsample2D(nn.Module):
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = weight.shape[1]
num_groups = hidden_states.shape[1] // inC
# Transpose weights.
@@ -214,20 +227,23 @@ class FirDownsample2D(nn.Module):
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
Scaling factor for signal magnitude (default: 1.0).
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight:
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
same datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -251,17 +267,17 @@ class FirDownsample2D(nn.Module):
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
hidden_states = upfirdn2d_native(
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return hidden_states
return output
def forward(self, hidden_states):
if self.use_conv:
@@ -393,20 +409,20 @@ class Mish(torch.nn.Module):
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
a multiple of the upsampling factor.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
output: Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
@@ -419,30 +435,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output
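For intuition, here is a minimal, self-contained sketch (not the library implementation; the helper name is ours) of what `upsample_2d` computes with the default `[1] * factor` kernel: zero-stuff the input, then filter with the normalized box kernel, which reproduces nearest-neighbor upsampling.
import torch
import torch.nn.functional as F
def naive_fir_upsample(x, factor=2, gain=1.0):
    # Zero-stuff: insert (factor - 1) zeros between pixels along H and W, then
    # filter with the [1] * factor box kernel. After the gain * factor**2
    # rescaling described in the docstring, the 2D kernel is just all ones.
    n, c, h, w = x.shape
    stuffed = x.new_zeros(n, c, h * factor, w * factor)
    stuffed[:, :, ::factor, ::factor] = x
    weight = torch.ones(c, 1, factor, factor) * gain
    padded = F.pad(stuffed, (factor - 1, 0, factor - 1, 0))
    return F.conv2d(padded, weight, groups=c)
x = torch.arange(4.0).reshape(1, 1, 2, 2)
print(naive_fir_upsample(x))  # each input pixel repeated as a 2x2 block (nearest neighbor)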
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
output: Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
@@ -456,34 +473,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
)
return output
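The downsampling counterpart: with the default `[1] * factor` kernel, `downsample_2d` reduces to plain average pooling. A minimal sketch (helper name is ours) that checks the equivalence:
import torch
import torch.nn.functional as F
def naive_fir_downsample(x, factor=2, gain=1.0):
    # Normalized box kernel: each output pixel is the (gain-scaled) mean of a
    # factor x factor block, sampled with stride `factor`.
    n, c, h, w = x.shape
    weight = torch.full((c, 1, factor, factor), gain / factor**2)
    return F.conv2d(x, weight, stride=factor, groups=c)
x = torch.randn(1, 4, 8, 8)
assert torch.allclose(naive_fir_downsample(x), F.avg_pool2d(x, 2), atol=1e-6)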
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, channel, in_h, in_w = tensor.shape
tensor = tensor.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
_, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out.to(tensor.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax
import jax.numpy as jnp

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -223,7 +236,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
"""r
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
import flax

View File

@@ -10,10 +10,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import numpy as np
# limitations under the License.
import numpy as np
import torch
from torch import nn

View File

@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax.numpy as jnp

View File

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

View File

@@ -1,3 +1,17 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
import math
@@ -119,6 +133,8 @@ class FlaxResnetBlock2D(nn.Module):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm.
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +144,14 @@ class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout: float = 0.0
groups: int = 32
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
@@ -143,7 +160,7 @@ class FlaxResnetBlock2D(nn.Module):
dtype=self.dtype,
)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
@@ -191,12 +208,15 @@ class FlaxAttentionBlock(nn.Module):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
num_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
channels: int
num_head_channels: int = None
num_groups: int = 32
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -204,7 +224,7 @@ class FlaxAttentionBlock(nn.Module):
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
self.query, self.key, self.value = dense(), dense(), dense()
self.proj_attn = dense()
@@ -264,6 +284,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +295,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
@@ -285,6 +308,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -303,9 +327,9 @@ class FlaxDownEncoderBlock2D(nn.Module):
return hidden_states
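To see the new `resnet_groups` knob in use, a hedged sketch (the import path and the `__call__(hidden_states, deterministic)` signature are assumptions based on the surrounding code; these Flax blocks use NHWC layout):
import jax
import jax.numpy as jnp
from diffusers.models.vae_flax import FlaxDownEncoderBlock2D  # import path is an assumption
block = FlaxDownEncoderBlock2D(in_channels=64, out_channels=128, num_layers=2, resnet_groups=16)
x = jnp.zeros((1, 32, 32, 64))  # NHWC
params = block.init(jax.random.PRNGKey(0), x, deterministic=True)
y = block.apply(params, x, deterministic=True)  # spatial dims halved since add_downsample defaults to True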
class FlaxUpEncoderBlock2D(nn.Module):
class FlaxUpDecoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Flax Resnet blocks-based Decoder block for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
@@ -316,8 +340,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -325,6 +351,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
@@ -336,6 +363,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -366,6 +394,8 @@ class FlaxUNetMidBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +404,20 @@ class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
]
@@ -392,7 +426,10 @@ class FlaxUNetMidBlock2D(nn.Module):
for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
channels=self.in_channels,
num_head_channels=self.attn_num_head_channels,
num_groups=resnet_groups,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -400,6 +437,7 @@ class FlaxUNetMidBlock2D(nn.Module):
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -441,7 +479,7 @@ class FlaxEncoder(nn.Module):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of Resnet layer for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm.
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
@@ -483,6 +521,7 @@ class FlaxEncoder(nn.Module):
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_groups=self.norm_num_groups,
add_downsample=not is_final_block,
dtype=self.dtype,
)
@@ -491,12 +530,15 @@ class FlaxEncoder(nn.Module):
# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)
# end
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
conv_out_channels,
kernel_size=(3, 3),
@@ -581,7 +623,10 @@ class FlaxDecoder(nn.Module):
# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)
# upsampling
@@ -594,10 +639,11 @@ class FlaxDecoder(nn.Module):
is_final_block = i == len(block_out_channels) - 1
up_block = FlaxUpEncoderBlock2D(
up_block = FlaxUpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_groups=self.norm_num_groups,
add_upsample=not is_final_block,
dtype=self.dtype,
)
@@ -607,7 +653,7 @@ class FlaxDecoder(nn.Module):
self.up_blocks = up_blocks
# end
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),

View File

@@ -79,8 +79,10 @@ class OnnxRuntimeModel:
src_path = self.model_save_dir.joinpath(self.latest_model_name)
dst_path = Path(save_directory).joinpath(model_file_name)
if not src_path.samefile(dst_path):
try:
shutil.copyfile(src_path, dst_path)
except shutil.SameFileError:
pass
def save_pretrained(
self,

View File

@@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DummyChecker:
def __init__(self):
self.dummy = True
def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
@@ -116,24 +111,27 @@ class FlaxDiffusionPipeline(ConfigMixin):
from diffusers import pipelines
for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
# retrieve library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}
# save model index config
self.register_to_config(**register_dict)
@@ -177,10 +175,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
if save_method_name is not None:
break
# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue
save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
@@ -194,7 +188,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
@@ -329,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
class_name = (
config_dict["_class_name"]
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
pipeline_class = getattr(diffusers_module, class_name)
# some modules can be passed directly to the init
@@ -349,13 +348,9 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
@@ -376,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -386,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
class_obj = import_flax_or_no_model(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
class_obj = import_flax_or_no_model(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
@@ -422,11 +417,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
if from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params
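With the `from_pt` branch above, a PyTorch-only checkpoint can be loaded straight into the Flax pipeline. A sketch (the checkpoint id is only an example):
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
# Converts the PyTorch weights on the fly; no Flax checkpoint is required.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", from_pt=True, dtype=jnp.float32
)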

View File

@@ -29,14 +29,28 @@ from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from . import __version__
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
is_transformers_available,
logging,
)
if is_transformers_available():
from transformers import PreTrainedModel
INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
logger = logging.get_logger(__name__)
@@ -99,23 +113,26 @@ class DiffusionPipeline(ConfigMixin):
for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}
# save model index config
self.register_to_config(**register_dict)
@@ -338,6 +355,7 @@ class DiffusionPipeline(ConfigMixin):
custom_pipeline = kwargs.pop("custom_pipeline", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -359,6 +377,11 @@ class DiffusionPipeline(ConfigMixin):
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -369,6 +392,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path
@@ -408,6 +432,7 @@ class DiffusionPipeline(ConfigMixin):
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
@@ -428,6 +453,12 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -448,21 +479,39 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name)
if load_method_name is None:
none_module = class_obj.__module__
if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
load_method = getattr(class_obj, load_method_name)
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
if (
issubclass(class_obj, diffusers.ModelMixin)
or is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
):
loading_kwargs["device_map"] = device_map
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
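The new `device_map` kwarg is forwarded to every sub-model that supports it (`ModelMixin` subclasses and transformers `PreTrainedModel` components). A sketch, assuming an accelerate-style value such as "auto":
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # example checkpoint id
    torch_dtype=torch.float16,
    device_map="auto",  # forwarded to the unet, vae, text encoder, and safety checker
)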

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
@@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
(nsfw) content, or `None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
nsfw_content_detected: Optional[List[bool]]
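In practice the checker can now be opted out of at load time, in which case the output field is `None`. A sketch (checkpoint id is an example):
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,  # triggers the warning added in this change set
)
result = pipe("a photo of an astronaut riding a horse")
print(result.nsfw_content_detected)  # None when safety checking is disabled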
if is_transformers_available() and is_torch_available():

View File

@@ -1,17 +1,27 @@
from functools import partial
from typing import Dict, List, Optional, Union
import numpy as np
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -52,9 +62,18 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.dtype = dtype
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -78,60 +97,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
)
return text_input.input_ids
def __call__(
def _get_safety_scores(self, features, params):
special_cos_dist, cos_dist = self.safety_checker(features, params)
return (special_cos_dist, cos_dist)
def _run_safety_checker(self, images, safety_model_params, jit=False):
# safety_model_params should already be replicated when jit is True
pil_images = [Image.fromarray(image) for image in images]
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
if jit:
features = shard(features)
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
special_cos_dist = unshard(special_cos_dist)
cos_dist = unshard(cos_dist)
safety_model_params = unreplicate(safety_model_params)
else:
special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
images, has_nsfw = self.safety_checker.filtered_with_scores(
special_cos_dist,
cos_dist,
images,
safety_model_params,
)
return images, has_nsfw
def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50,
height: Optional[int] = 512,
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
usually at the expense of lower image quality.
prng_seed (`jax.random.PRNGKey`):
A JAX PRNG key used to seed the sampling and make generation deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling with the supplied random `prng_seed`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -171,6 +174,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
@@ -190,6 +195,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
if debug:
# run with python for loop
for i in range(num_inference_steps):
@@ -199,21 +207,119 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker
# TODO: check when flax safety checker gets merged into main
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker(
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
# )
has_nsfw_concept = False
def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
usually at the expense of lower image quality.
prng_seed (`jax.random.PRNGKey`):
A JAX PRNG key used to seed the sampling and make generation deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling with the supplied random `prng_seed`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
if self.safety_checker is not None:
safety_params = params["safety_checker"]
images_uint8_casted = (images * 255).round().astype("uint8")
num_devices, batch_size = images.shape[:2]
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
images = np.asarray(images)
# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
images[i] = np.asarray(images_uint8_casted[i])
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
has_nsfw_concept = False
if not return_dict:
return (image, has_nsfw_concept)
return (images, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
return pipe._get_safety_scores(features, params)
def unshard(x: jnp.ndarray):
# einops.rearrange(x, 'd b ... -> (d b) ...')
num_devices, batch_size = x.shape[:2]
rest = x.shape[2:]
return x.reshape(num_devices * batch_size, *rest)
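Putting the pieces together, a hedged end-to-end sketch of the pmapped path; `prepare_inputs` and the `flax` revision are assumptions, and `jit=True` routes through `_p_generate` and `_p_get_safety_scores` above:
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.float32
)
prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)
params = replicate(params)      # one copy of the weights per device
prompt_ids = shard(prompt_ids)  # leading dim becomes (devices, batch, ...)
rngs = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
images = pipeline(prompt_ids, params, rngs, jit=True).images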

View File

@@ -71,6 +71,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -234,8 +244,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
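For reference, the check fires when `prompt` and `negative_prompt` disagree in type; matching usage looks like this (checkpoint id is an example):
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# `negative_prompt` must be the same type as `prompt` (str with str, list with list),
# otherwise the TypeError above is raised, now with the types interpolated.
image = pipe("a portrait photo, detailed", negative_prompt="blurry, low quality").images[0]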
@@ -331,12 +341,19 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@@ -83,6 +83,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -249,8 +259,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
@@ -284,8 +294,25 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
if isinstance(prompt, str):
prompt = [prompt]
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
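A sketch of the (deprecated, warning-emitting) behavior the new branch still permits: one init image reused for several prompts. The file name and checkpoint id are placeholders:
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
init = Image.open("sketch.png").convert("RGB").resize((512, 512))
# 2 prompts, 1 init image: the image is duplicated to match (warns; removed in 1.0.0).
out = pipe(prompt=["a watercolor", "an oil painting"], init_image=init, strength=0.75)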
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
@@ -342,10 +369,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@@ -98,6 +98,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -266,8 +276,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
@@ -382,8 +392,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@@ -108,8 +108,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size

View File

@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
@@ -28,16 +30,17 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.register_buffer("concept_embeds_weights", torch.ones(17))
self.register_buffer("special_care_embeds_weights", torch.ones(3))
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]

View File

@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
@@ -192,7 +192,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
"offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset
@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, generator=generator).to(device)
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance

View File

@@ -141,6 +141,23 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep
Returns:
`jnp.ndarray`: scaled input sample
"""
return sample
def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod

View File

@@ -116,7 +116,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
@@ -133,6 +133,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
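The new sigmoid ("GeoDiff") schedule as a standalone sketch, using the scheduler's usual defaults:
import torch
beta_start, beta_end, num_train_timesteps = 1e-4, 2e-2, 1000
betas = torch.sigmoid(torch.linspace(-6, 6, num_train_timesteps)) * (beta_end - beta_start) + beta_start
print(betas[0].item(), betas[-1].item())  # ramps smoothly from ~beta_start to ~beta_end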

View File

@@ -90,7 +90,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

View File

@@ -78,7 +78,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
@@ -217,7 +217,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"timestep as an index",
"0.5.0",
"0.7.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
@@ -267,7 +267,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
deprecate(
"timesteps as indices",
"0.5.0",
"0.7.0",
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
" pass values from `scheduler.timesteps` as timesteps.",

View File

@@ -104,7 +104,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
@@ -159,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
deprecated_offset = deprecate(
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
"offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
)
offset = deprecated_offset or self.config.steps_offset

View File

@@ -153,6 +153,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
@@ -196,7 +199,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
return state.replace(
- timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
+ timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
counter=0,
# Reserve space for the state variables
cur_model_output=jnp.zeros(shape),
@@ -204,6 +207,23 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
ets=jnp.zeros((4,) + shape),
)
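The int64-to-int32 switch reflects JAX's default configuration: unless `jax_enable_x64` is enabled, 64-bit arrays are not available and a requested int64 is downgraded (with a warning) to int32, so asking for int32 up front makes the dtype deterministic across configurations. For illustration:

import jax.numpy as jnp

# Under the default jax_enable_x64=False, the requested int64 is downgraded
# to int32 (JAX emits a warning rather than honoring the wider dtype).
timesteps = jnp.arange(10).astype(jnp.int64)
print(timesteps.dtype)  # int32 on a default-configured JAX install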
def scale_model_input(
self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
sample (`jnp.ndarray`): input sample
timestep (`int`, optional): current timestep
Returns:
`jnp.ndarray`: scaled input sample
"""
return sample
def step(
self,
state: PNDMSchedulerState,

View File

@@ -79,7 +79,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)
@@ -156,10 +156,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
- def set_seed(self, seed):
- deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
- torch.manual_seed(seed)
def step_pred(
self,
model_output: torch.FloatTensor,
@@ -167,7 +163,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- **kwargs,
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -186,9 +181,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -231,7 +223,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- **kwargs,
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
@@ -249,9 +240,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"

View File

@@ -43,7 +43,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
deprecate(
"tensor_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this argument.",
take_from=kwargs,
)

View File

@@ -45,7 +45,7 @@ class SchedulerMixin:
def set_format(self, tensor_format="pt"):
deprecate(
"set_format",
"0.5.0",
"0.6.0",
"If you're running your code in PyTorch, you can safely remove this function as the schedulers are always"
" in Pytorch",
)

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import warnings
from dataclasses import dataclass
import jax.numpy as jnp
@@ -42,12 +41,3 @@ class FlaxSchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
- def set_format(self, tensor_format="pt"):
- warnings.warn(
- "The method `set_format` is deprecated and will be removed in version `0.5.0`."
- "If you're running your code in PyTorch, you can safely remove this function as the schedulers"
- "are always in Pytorch",
- DeprecationWarning,
- )
- return self

View File

@@ -9,3 +9,11 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])

View File

@@ -10,6 +10,14 @@ class FlaxModelMixin(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
@@ -17,6 +25,14 @@ class FlaxUNet2DConditionModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]
@@ -24,6 +40,14 @@ class FlaxAutoencoderKL(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax"]
@@ -31,6 +55,14 @@ class FlaxDiffusionPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -38,6 +70,14 @@ class FlaxDDIMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -45,6 +85,14 @@ class FlaxDDPMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -52,6 +100,14 @@ class FlaxKarrasVeScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -59,6 +115,14 @@ class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -66,6 +130,14 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxSchedulerMixin(metaclass=DummyObject):
_backends = ["flax"]
@@ -73,9 +145,25 @@ class FlaxSchedulerMixin(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

View File

@@ -10,6 +10,14 @@ class ModelMixin(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
@@ -17,6 +25,14 @@ class AutoencoderKL(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -24,6 +40,14 @@ class UNet2DConditionModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -31,6 +55,14 @@ class UNet2DModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -38,6 +70,14 @@ class VQModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
@@ -73,6 +113,14 @@ class DiffusionPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -80,6 +128,14 @@ class DDIMPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -87,6 +143,14 @@ class DDPMPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -94,6 +158,14 @@ class KarrasVePipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -101,6 +173,14 @@ class LDMPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -108,6 +188,14 @@ class PNDMPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -115,6 +203,14 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -122,6 +218,14 @@ class DDIMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -129,6 +233,14 @@ class DDPMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -136,6 +248,14 @@ class KarrasVeScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -143,6 +263,14 @@ class PNDMScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]
@@ -150,6 +278,14 @@ class SchedulerMixin(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -157,9 +293,25 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

View File

@@ -9,3 +9,11 @@ class LMSDiscreteScheduler(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "scipy"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "scipy"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "scipy"])

View File

@@ -9,3 +9,11 @@ class StableDiffusionOnnxPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "onnx"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])

View File

@@ -10,6 +10,14 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -17,6 +25,14 @@ class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -24,9 +40,25 @@ class StableDiffusionInpaintPipeline(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

View File

@@ -7,20 +7,27 @@ from distutils.util import strtobool
from pathlib import Path
from typing import Union
- import torch
import PIL.Image
import PIL.ImageOps
import requests
from packaging import version
+ from .import_utils import is_flax_available, is_torch_available
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
if is_torch_available():
import torch
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
"1.12"
)
if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def get_tests_dir(append_path=None):
@@ -89,6 +96,13 @@ def slow(test_case):
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:

View File

@@ -0,0 +1,44 @@
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
if is_flax_available():
import jax
@require_flax
class FlaxModelTesterMixin:
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)
output = model.apply(variables, inputs_dict["sample"])
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

View File

@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
out.sum().backward()
# now we save the output and parameter gradients that we will use for comparison purposes with
# the non-checkpointed run.
output_not_checkpointed = out.data.clone()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
model.enable_gradient_checkpointing()
out = model(**inputs_dict).sample
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
out.sum().backward()
# now we save the output and parameter gradients that we will use for comparison purposes with
# the non-checkpointed run.
output_checkpointed = out.data.clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((output_checkpointed == output_not_checkpointed).all())
for name in grad_checkpointed:
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
# TODO(Patrick) - Re-add this test after having cleaned up LDM
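The rewritten test stops re-using one model for both passes: gradients left over from the first backward would otherwise leak into the checkpointed run unless carefully zeroed. A distilled sketch of the recipe, assuming a diffusers-style model with a `.sample` output and an `enable_gradient_checkpointing()` toggle:

import torch

def check_gradient_checkpointing(model_cls, init_kwargs, inputs, atol=5e-5):
    # Two independent instances with identical weights, so one backward pass
    # cannot contaminate the other's gradients.
    model = model_cls(**init_kwargs)
    model_ckpt = model_cls(**init_kwargs)
    model_ckpt.load_state_dict(model.state_dict())
    model_ckpt.enable_gradient_checkpointing()

    out = model(**inputs).sample
    labels = torch.randn_like(out)
    loss = (out - labels).mean()
    loss.backward()

    loss_ckpt = (model_ckpt(**inputs).sample - labels).mean()
    loss_ckpt.backward()

    assert (loss - loss_ckpt).abs() < 1e-5
    params_ckpt = dict(model_ckpt.named_parameters())
    for name, param in model.named_parameters():
        assert torch.allclose(param.grad, params_ckpt[name].grad, atol=atol), name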

View File

@@ -0,0 +1,39 @@
import unittest
from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
from .test_modeling_common_flax import FlaxModelTesterMixin
if is_flax_available():
import jax
@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
model_class = FlaxAutoencoderKL
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
prng_key = jax.random.PRNGKey(0)
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
return {"sample": image, "prng_key": prng_key}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

View File

@@ -17,12 +17,15 @@ import gc
import os
import random
import tempfile
import tracemalloc
import unittest
import numpy as np
import torch
import accelerate
import PIL
import transformers
from diffusers import (
AutoencoderKL,
DDIMPipeline,
@@ -50,6 +53,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -488,6 +492,23 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_from_pretrained_error_message_uninstalled_packages(self):
# TODO(Patrick, Pedro) - need better test here for the future
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-lms-pipe")
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
def test_stable_diffusion_no_safety_checker(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
)
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
@@ -565,6 +586,46 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
negative_prompt = "french fries"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
prompt,
negative_prompt=negative_prompt,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = ScoreSdeVeScheduler()
@@ -694,6 +755,90 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
negative_prompt = "french fries"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
prompt,
negative_prompt=negative_prompt,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_multiple_init_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = 2 * ["A painting of a squirrel eating a burger"]
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
prompt,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
)
image = output.images
image_slice = image[-1, -3:, -3:, -1]
assert image.shape == (2, 32, 32, 3)
expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
@@ -809,6 +954,52 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
negative_prompt = "french fries"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
prompt,
negative_prompt=negative_prompt,
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
@@ -1851,13 +2042,21 @@ class PipelineTesterMixin(unittest.TestCase):
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 50:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
test_callback_fn.has_been_called = False
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
- pipe.to(torch_device)
+ pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
@@ -1891,6 +2090,12 @@ class PipelineTesterMixin(unittest.TestCase):
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 37:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 96)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
test_callback_fn.has_been_called = False
@@ -1941,6 +2146,12 @@ class PipelineTesterMixin(unittest.TestCase):
[-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 37:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659])
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
@@ -1993,6 +2204,13 @@ class PipelineTesterMixin(unittest.TestCase):
[-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
elif step == 5:
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
test_callback_fn.has_been_called = False
@@ -2007,3 +2225,53 @@ class PipelineTesterMixin(unittest.TestCase):
pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
assert test_callback_fn.has_been_called
assert number_of_steps == 6
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_accelerate_load_works(self):
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
return
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
return
model_id = "CompVis/stable-diffusion-v1-4"
_ = StableDiffusionPipeline.from_pretrained(
model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
).to(torch_device)
@slow
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
return
if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
return
pipeline_id = "CompVis/stable-diffusion-v1-4"
torch.cuda.empty_cache()
gc.collect()
tracemalloc.start()
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
)
pipeline_normal_load.to(torch_device)
_, peak_normal = tracemalloc.get_traced_memory()
tracemalloc.stop()
del pipeline_normal_load
torch.cuda.empty_cache()
gc.collect()
tracemalloc.start()
_ = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
)
_, peak_accelerate = tracemalloc.get_traced_memory()
tracemalloc.stop()
assert peak_accelerate < peak_normal
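`tracemalloc` tracks Python-heap allocations only, which is a reasonable proxy for the CPU-side weight materialization that `device_map="auto"` is meant to avoid (it says nothing about CUDA memory). The measurement pattern used above boils down to a helper like this:

import gc
import tracemalloc

def peak_python_memory(fn):
    # Peak Python-heap usage while fn runs; CUDA allocations are not counted.
    gc.collect()
    tracemalloc.start()
    fn()
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak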

View File

@@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow
if is_flax_available():
import jax
import jax.numpy as jnp
from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
@require_flax
@slow
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 4
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
assert images.shape == (8, 1, 64, 64, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 5e-1
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == 8
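The recipe above: `replicate` copies the params to every device, `shard` and `jax.random.split` give each input and RNG key a leading device axis, and `num_inference_steps` is marked static so `pmap` treats it as a compile-time constant. The same data-parallel pattern with a toy function standing in for the pipeline:

import jax
import jax.numpy as jnp

devices = jax.device_count()
rngs = jax.random.split(jax.random.PRNGKey(0), devices)  # one RNG key per device
batch = jnp.arange(devices * 4, dtype=jnp.float32).reshape(devices, 4)

# Every sharded argument's leading axis must equal the device count.
p_fn = jax.pmap(lambda x, rng: x + jax.random.normal(rng, x.shape))
out = p_fn(batch, rngs)
assert out.shape == (devices, 4)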
def test_stable_diffusion_v1_4(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
scheduler = FlaxDDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
steps_offset=1,
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=jnp.bfloat16,
scheduler=scheduler,
safety_checker=None,
)
scheduler_state = scheduler.create_state()
params["scheduler"] = scheduler_state
prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1

View File

@@ -38,6 +38,14 @@ class {0}(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, {1})
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, {1})
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, {1})
"""