mirror of https://github.com/huggingface/diffusers.git

Compare commits (43 commits)
| SHA1 |
|---|
| 0679d09083 |
| 1d51224403 |
| 7c2262640b |
| 78db11dbf3 |
| e713346ad1 |
| 26c7df5d82 |
| e001fededf |
| 0a09af2f0a |
| f1d4289be8 |
| 323a9e1f6d |
| 60c384bcd2 |
| 008b608f15 |
| 5afc2b60cd |
| 96598639c0 |
| 80be0744a6 |
| 679c77f8ea |
| db47b1e4d9 |
| 966e2fc461 |
| 6bc11782b7 |
| c1b6ea3dce |
| 24b8b5cf5e |
| 757babfcad |
| e895952816 |
| a124204490 |
| 66a5279a94 |
| 797b290ed0 |
| 81bdbb5e2a |
| 71ca10c6a4 |
| 22963ed826 |
| fab17528da |
| feaa73243d |
| a73f8b7251 |
| 5af6eed9ee |
| f3983d16ee |
| 92d7086366 |
| ec831b6a72 |
| cb0bf0bd0b |
| e0fece2b26 |
| 75bb6d2d46 |
| 906e4105d7 |
| 7258dc4943 |
| c93a8cc901 |
| 9a95414ea1 |
@@ -1 +1,2 @@
 include LICENSE
+include src/diffusers/utils/model_card_template.md
@@ -1,6 +1,6 @@
 <p align="center">
     <br>
-    <img src="docs/source/imgs/diffusers_library.jpg" width="400"/>
+    <img src="https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg" width="400"/>
     <br>
 <p>
 <p align="center">
File diff suppressed because one or more lines are too long
@@ -91,24 +91,24 @@ class MyPipeline(DiffusionPipeline):
        # Sample gaussian noise to begin loop
        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))

        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return image
```

Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
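Once the file is on the Hub, a pipeline defined this way can be loaded through the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`. A minimal sketch, assuming compatible weights and a hypothetical repo `your-username/my-custom-pipeline` hosting the `pipeline.py`:

```python
from diffusers import DiffusionPipeline

# "google/ddpm-cifar10-32" supplies unet/scheduler weights; `custom_pipeline`
# points at the Hub repo containing pipeline.py (the repo id is a placeholder).
pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="your-username/my-custom-pipeline"
)
output = pipeline()
```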
@@ -15,6 +15,7 @@ specific language governing permissions and limitations under the License.

The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.

```python
import torch
import requests
from PIL import Image
from io import BytesIO
```
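For orientation, a minimal end-to-end sketch of how this pipeline is typically driven (the `init_image` and `strength` parameter names reflect the 0.x-era API, and the prompt and image URL are illustrative placeholders):

```python
import requests
from io import BytesIO

from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

# `strength` in [0, 1] controls how strongly the initial image is noised before denoising
image = pipe(prompt="A fantasy landscape", init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]
image.save("fantasy_landscape.png")
```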
@@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \
  --max_train_steps=800
```

### Training on an 8 GB GPU

By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.

DeepSpeed needs to be enabled with `accelerate config`. During configuration, answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 mixed precision, and offloading of both parameters and optimizer state to CPU, it's possible to train on under 8 GB of VRAM, at the cost of requiring significantly more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.

Changing the default Adam optimizer to DeepSpeed's optimized version of Adam, `deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer does not currently appear to be compatible with DeepSpeed.
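The `accelerate config` questionnaire writes these choices to a YAML file (typically `~/.cache/huggingface/accelerate/default_config.yaml`). A minimal sketch of what the relevant entries might look like for the setup described above — the key names follow Accelerate's DeepSpeed plugin, but treat the exact snippet as an assumption to adapt rather than a verified recipe:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
mixed_precision: fp16
deepspeed_config:
  zero_stage: 2                   # ZeRO stage 2
  offload_optimizer_device: cpu   # optimizer state lives in system RAM
  offload_param_device: cpu       # parameters are offloaded to system RAM
  gradient_accumulation_steps: 1
```

With that configured, the launch command is the same as before, plus `--mixed_precision=fp16`: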
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800 \
  --mixed_precision=fp16
```

## Inference
@@ -77,7 +77,7 @@ def parse_args():
        type=int,
        default=100,
        help=(
-            "Minimal class images for prior perversation loss. If not have enough images, additional images will be"
+            "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
            " sampled with class_prompt."
        ),
    )
@@ -471,9 +471,17 @@ def main():
        unet, optimizer, train_dataloader, lr_scheduler
    )

-    # Move text_encode and vae to gpu
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encode and vae to gpu.
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -509,11 +517,11 @@ def main():
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215

                # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -536,15 +544,15 @@ def main():
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
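For context, the objective computed in these lines is the DreamBooth prior-preservation loss: an instance term plus a weighted class-prior term. A sketch in the paper's notation, with $\lambda$ corresponding to `prior_loss_weight` and $\hat{\epsilon}_\theta$ the UNet's noise prediction:

$$
\mathcal{L} = \left\| \hat{\epsilon}_\theta(z_t, t, c) - \epsilon \right\|_2^2 + \lambda \left\| \hat{\epsilon}_\theta(z_t^{\mathrm{pr}}, t, c^{\mathrm{pr}}) - \epsilon^{\mathrm{pr}} \right\|_2^2
$$

The `.float()` casts keep the loss computation in full precision even when the forward pass runs in fp16/bf16.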
@@ -575,9 +583,7 @@ def main():
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
-            repo.push_to_hub(
-                args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()
examples/text_to_image/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# Stable Diffusion text-to-image fine-tuning

The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.

___Note___:

___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___


## Running locally
### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

### Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license, and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps.

<br>

#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24 GB GPU. For higher `batch_size` and faster training it's better to use GPUs with more than 30 GB of memory.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
```

To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata), and a sketch of the expected layout is shown below. If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
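A minimal sketch of that `imagefolder` layout (file names and captions here are placeholders; the `text` key matches the script's default `--caption_column`):

```
path_to_your_dataset/
├── metadata.jsonl
├── 0001.png
├── 0002.png
└── ...
```

Each line of `metadata.jsonl` pairs an image file with its caption, e.g. `{"file_name": "0001.png", "text": "a drawing of a green pokemon with red eyes"}`.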
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
```
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:

```python
import torch

from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
examples/text_to_image/requirements.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
diffusers==0.4.1
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards
examples/text_to_image/train_text_to_image.py (new file, 627 lines)
@@ -0,0 +1,627 @@
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


dataset_name_mapping = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}

# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
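        # Note: `value` ramps from ~0 toward 1 as training progresses and is capped
        # at `self.decay`; the method returns `1 - min(decay, value)`, i.e. the
        # fraction of the live weights mixed into the shadow copy on each `step`
        # call. Early in training the shadow weights therefore track the live
        # weights closely, and the averaging only slows down later.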
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        parameters = list(parameters)

        self.optimization_step += 1
        self.decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                tmp = self.decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

def main():
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
        input_ids = inputs.input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = [example["input_ids"] for example in examples]
        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
        return {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
            "attention_mask": padded_tokens.attention_mask,
        }

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # Create EMA for the unet.
    if args.use_ema:
        ema_unet = EMAModel(unet.parameters())

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
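                # 0.18215 is the latent scaling factor from the Stable Diffusion VAE
                # configuration; it rescales latents to roughly unit variance.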
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
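                # i.e. noisy_latents = sqrt(alphas_cumprod[t]) * latents + sqrt(1 - alphas_cumprod[t]) * noise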
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual and compute loss
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        )
        pipeline.save_pretrained(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
@@ -569,9 +569,7 @@ def main():
        save_progress(text_encoder, placeholder_token_id, accelerator, args)

        if args.push_to_hub:
-            repo.push_to_hub(
-                args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()
@@ -9,7 +9,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.hub_utils import init_git_repo
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from torchvision.transforms import (
@@ -185,7 +185,7 @@ def main(args):
            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model
                if args.push_to_hub:
-                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()
@@ -1,102 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, Union

import torch

from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class CustomPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
            generator=generator,
        )
        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image), "This is a test"
@@ -1 +0,0 @@
-b8fa12635e53eebebc22f95ee863e7af4fc2fb07
@@ -1 +0,0 @@
-../../blobs/bbbcb9f65616524d6199fa3bc16dc0500fb2cbbb
@@ -206,7 +206,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--opset",
        default=14,
-        type=str,
+        type=int,
        help="The version of the ONNX operator set to use.",
    )
setup.py (2 lines changed)
@@ -211,7 +211,7 @@ install_requires = [

setup(
    name="diffusers",
-    version="0.4.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.5.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    description="Diffusers",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
@@ -9,7 +9,7 @@ from .utils import (
)


-__version__ = "0.4.2"
+__version__ = "0.5.0"

from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -27,7 +27,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

-from . import is_torch_available
+from . import __version__, is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .utils import (
    CONFIG_NAME,
@@ -286,10 +286,13 @@ class FlaxModelMixin:
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
-        from_auto_class = kwargs.pop("_from_auto", False)
        subfolder = kwargs.pop("subfolder", None)

-        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "flax",
+        }

        # Load config if we don't provide a configuration
        config_path = config if config is not None else pretrained_model_name_or_path
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
@@ -292,12 +293,15 @@ class ModelMixin(torch.nn.Module):
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
@@ -1,3 +1,16 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import math
 from typing import Optional
@@ -9,9 +9,10 @@ class Upsample2D(nn.Module):
    """
    An upsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
-    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-    upsampling occurs in the inner-two dimensions.
+    Parameters:
+        channels: channels in the inputs and outputs.
+        use_conv: a bool determining if a convolution is applied.
+        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -40,6 +41,13 @@ class Upsample2D(nn.Module):
        if self.use_conv_transpose:
            return self.conv(hidden_states)

+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+        # https://github.com/pytorch/pytorch/issues/86679
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.float32)
+
        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
@@ -47,6 +55,10 @@ class Upsample2D(nn.Module):
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

+        # If the input is bfloat16, we cast back to bfloat16
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(dtype)
+
        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
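For context, the failure these casts work around can be sketched as follows (a minimal repro, assuming a PyTorch version affected by the linked issue; on fixed versions the direct call succeeds):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)

# F.interpolate(x, scale_factor=2.0, mode="nearest") raises on affected PyTorch
# versions, so the layer round-trips through float32 instead:
y = F.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
```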
@@ -61,9 +73,10 @@ class Downsample2D(nn.Module):
    """
    A downsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
-    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-    downsampling occurs in the inner-two dimensions.
+    Parameters:
+        channels: channels in the inputs and outputs.
+        use_conv: a bool determining if a convolution is applied.
+        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -115,21 +128,22 @@ class FirUpsample2D(nn.Module):
    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

-        Args:
        Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
-        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
-        order.
-            x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
-            C]`.
-            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
-            outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-            factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
+        efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+        arbitrary order.
+
+        Args:
+            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+            factor: Integer upsampling factor (default: 2).
+            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
-            Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
-            `x`.
+            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+                datatype as `hidden_states`.
        """

        assert isinstance(factor, int) and factor >= 1
@@ -164,7 +178,6 @@ class FirUpsample2D(nn.Module):
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
-            inC = weight.shape[1]
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
|
||||
|
||||
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
|
||||
order.
|
||||
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
|
||||
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
|
||||
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
||||
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight:
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
||||
factor`, which corresponds to average pooling.
|
||||
factor: Integer downsampling factor (default: 2).
|
||||
gain: Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
||||
same datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
@@ -251,17 +267,17 @@ class FirDownsample2D(nn.Module):
                torch.tensor(kernel, device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
-            hidden_states = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
-            hidden_states = upfirdn2d_native(
+            output = upfirdn2d_native(
                hidden_states,
                torch.tensor(kernel, device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

-        return hidden_states
+        return output

    def forward(self, hidden_states):
        if self.use_conv:
@@ -393,20 +409,20 @@ class Mish(torch.nn.Module):

def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.

Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
a: multiple of the upsampling factor.

Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
output: Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
@@ -419,30 +435,31 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):

kernel = kernel * (gain * (factor**2))
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
up=factor,
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
)
return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.

Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.

Args:
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).

Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
output: Tensor of the shape `[N, C, H // factor, W // factor]`
"""

assert isinstance(factor, int) and factor >= 1
@@ -456,34 +473,34 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):

kernel = kernel * gain
pad_value = kernel.shape[0] - factor
return upfirdn2d_native(
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
)
return output
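The default-kernel claims in these two docstrings are easy to sanity-check. A minimal sketch, assuming the functions are importable from `diffusers.models.resnet` (the import path is an assumption, not shown in this diff): with the default `[1] * factor` kernel, `upsample_2d` should match nearest-neighbor upsampling and `downsample_2d` should match average pooling.

```python
import torch
import torch.nn.functional as F

# import path is an assumption for illustration
from diffusers.models.resnet import downsample_2d, upsample_2d

x = torch.randn(1, 3, 8, 8)

# default kernel `[1] * factor` -> nearest-neighbor upsampling
assert torch.allclose(
    upsample_2d(x, factor=2), F.interpolate(x, scale_factor=2, mode="nearest"), atol=1e-6
)

# default kernel `[1] * factor` -> 2x2 average pooling
assert torch.allclose(downsample_2d(x, factor=2), F.avg_pool2d(x, 2), atol=1e-6)
```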
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]

_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
# Rename this variable (input); it shadows a builtin. sonarlint(python:S5806)
_, channel, in_h, in_w = tensor.shape
tensor = tensor.reshape(-1, in_h, in_w, 1)

_, in_h, in_w, minor = input.shape
_, in_h, in_w, minor = tensor.shape
kernel_h, kernel_w = kernel.shape

out = input.view(-1, in_h, 1, in_w, 1, minor)
out = tensor.view(-1, in_h, 1, in_w, 1, minor)

# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)

out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device)  # Move back to mps if necessary
out = out.to(tensor.device)  # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
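For intuition, `upfirdn2d_native` implements the classic upsample–FIR-filter–downsample primitive: zero-stuff by `up`, pad (negative pads crop), convolve with the flipped 2D kernel, then keep every `down`-th sample. A compact, unoptimized reference of the same idea (a hypothetical helper for illustration only, not the fused implementation above):

```python
import torch
import torch.nn.functional as F

def upfirdn2d_reference(x, kernel, up=1, down=1, pad=(0, 0)):
    # x: [N, C, H, W]; kernel: 2D FIR filter. Illustration only.
    n, c, h, w = x.shape
    x = x.reshape(n * c, 1, h, w)
    # 1. upsample by inserting zeros between input samples
    z = torch.zeros(n * c, 1, h * up, w * up, dtype=x.dtype, device=x.device)
    z[..., ::up, ::up] = x
    # 2. pad (F.pad crops when a pad value is negative)
    z = F.pad(z, [pad[0], pad[1], pad[0], pad[1]])
    # 3. true convolution: flip the kernel because conv2d cross-correlates
    k = torch.flip(kernel, dims=[0, 1]).view(1, 1, *kernel.shape).to(z.dtype)
    z = F.conv2d(z, k)
    # 4. downsample by keeping every `down`-th sample
    z = z[..., ::down, ::down]
    return z.reshape(n, c, z.shape[-2], z.shape[-1])
```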
@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax
import jax.numpy as jnp

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -223,7 +236,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
"""r
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union

import flax

@@ -10,10 +10,8 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import numpy as np

# limitations under the License.
import numpy as np
import torch
from torch import nn


@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp

@@ -1,3 +1,16 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@@ -1,3 +1,17 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

import math
@@ -119,6 +133,8 @@ class FlaxResnetBlock2D(nn.Module):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm.
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +144,14 @@ class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout: float = 0.0
groups: int = 32
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32

def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels

self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
@@ -143,7 +160,7 @@ class FlaxResnetBlock2D(nn.Module):
dtype=self.dtype,
)

self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
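The point of this change: `groups` was accepted as a field but both `GroupNorm` layers were hard-coded to 32 groups, so non-default values were silently ignored. A hedged sketch of the now-honored parameter (the import path and the NHWC call convention are assumptions based on this era of the codebase):

```python
import jax
import jax.numpy as jnp

from diffusers.models.vae_flax import FlaxResnetBlock2D  # import path is an assumption

block = FlaxResnetBlock2D(in_channels=64, out_channels=64, groups=16)
x = jnp.ones((1, 32, 32, 64))  # NHWC, as used by these Flax VAE blocks
params = block.init(jax.random.PRNGKey(0), x, deterministic=True)
# After the fix, both norm layers are built with num_groups=16 instead of 32.
```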
@@ -191,12 +208,15 @@ class FlaxAttentionBlock(nn.Module):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
num_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`

"""
channels: int
num_head_channels: int = None
num_groups: int = 32
dtype: jnp.dtype = jnp.float32

def setup(self):
@@ -204,7 +224,7 @@ class FlaxAttentionBlock(nn.Module):

dense = partial(nn.Dense, self.channels, dtype=self.dtype)

self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
self.query, self.key, self.value = dense(), dense(), dense()
self.proj_attn = dense()

@@ -264,6 +284,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +295,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32

@@ -285,6 +308,7 @@ class FlaxDownEncoderBlock2D(nn.Module):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -303,9 +327,9 @@ class FlaxDownEncoderBlock2D(nn.Module):
return hidden_states


class FlaxUpEncoderBlock2D(nn.Module):
class FlaxUpDecoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Flax Resnet blocks-based Decoder block for diffusion-based VAE.

Parameters:
in_channels (:obj:`int`):
@@ -316,8 +340,10 @@ class FlaxUpEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -325,6 +351,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32

@@ -336,6 +363,7 @@ class FlaxUpEncoderBlock2D(nn.Module):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -366,6 +394,8 @@ class FlaxUNetMidBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +404,20 @@ class FlaxUNetMidBlock2D(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32

def setup(self):
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
]
@@ -392,7 +426,10 @@ class FlaxUNetMidBlock2D(nn.Module):

for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
channels=self.in_channels,
num_head_channels=self.attn_num_head_channels,
num_groups=resnet_groups,
dtype=self.dtype,
)
attentions.append(attn_block)

@@ -400,6 +437,7 @@ class FlaxUNetMidBlock2D(nn.Module):
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -441,7 +479,7 @@ class FlaxEncoder(nn.Module):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of Resnet layer for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
norm num group
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
@@ -483,6 +521,7 @@ class FlaxEncoder(nn.Module):
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_groups=self.norm_num_groups,
add_downsample=not is_final_block,
dtype=self.dtype,
)
@@ -491,12 +530,15 @@ class FlaxEncoder(nn.Module):

# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)

# end
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
conv_out_channels,
kernel_size=(3, 3),
@@ -581,7 +623,10 @@ class FlaxDecoder(nn.Module):

# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)

# upsampling
@@ -594,10 +639,11 @@ class FlaxDecoder(nn.Module):

is_final_block = i == len(block_out_channels) - 1

up_block = FlaxUpEncoderBlock2D(
up_block = FlaxUpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_groups=self.norm_num_groups,
add_upsample=not is_final_block,
dtype=self.dtype,
)
@@ -607,7 +653,7 @@ class FlaxDecoder(nn.Module):
self.up_blocks = up_blocks

# end
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),

@@ -79,8 +79,10 @@ class OnnxRuntimeModel:

src_path = self.model_save_dir.joinpath(self.latest_model_name)
dst_path = Path(save_directory).joinpath(model_file_name)
if not src_path.samefile(dst_path):
try:
shutil.copyfile(src_path, dst_path)
except shutil.SameFileError:
pass

def save_pretrained(
self,
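The replaced `samefile()` guard had a failure mode worth spelling out: `Path.samefile` stats both paths, so it raises `FileNotFoundError` when the destination does not exist yet, which is the common case on a first save. Asking forgiveness instead sidesteps the pre-flight stat entirely, as a standalone sketch of the pattern shows:

```python
import shutil

def copy_unless_same(src, dst):
    # shutil.copyfile raises SameFileError only when src and dst refer to the
    # same file; a missing dst is simply created, with no stat beforehand.
    try:
        shutil.copyfile(src, dst)
    except shutil.SameFileError:
        pass
```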
@@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


class DummyChecker:
def __init__(self):
self.dummy = True


def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
@@ -116,24 +111,27 @@ class FlaxDiffusionPipeline(ConfigMixin):
from diffusers import pipelines

for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
# retrieve library
library = module.__module__.split(".")[0]

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir

# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__

register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}

# save model index config
self.register_to_config(**register_dict)
@@ -177,10 +175,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
if save_method_name is not None:
break

# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue

save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())

@@ -194,7 +188,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.

The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

@@ -329,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
class_name = (
config_dict["_class_name"]
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

# some modules can be passed directly to the init
@@ -349,13 +348,9 @@ class FlaxDiffusionPipeline(ConfigMixin):

# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
@@ -376,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -386,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
class_obj = import_flax_or_no_model(pipeline_module, class_name)

importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
class_obj = import_flax_or_no_model(library, class_name)

importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
@@ -422,11 +417,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
if from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params
@@ -29,14 +29,28 @@ from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm

from . import __version__
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
is_transformers_available,
logging,
)


if is_transformers_available():
from transformers import PreTrainedModel


INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"


logger = logging.get_logger(__name__)
@@ -99,23 +113,26 @@ class DiffusionPipeline(ConfigMixin):

for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
library = module.__module__.split(".")[0]

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir

# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__

register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}

# save model index config
self.register_to_config(**register_dict)
@@ -338,6 +355,7 @@ class DiffusionPipeline(ConfigMixin):
custom_pipeline = kwargs.pop("custom_pipeline", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -359,6 +377,11 @@ class DiffusionPipeline(ConfigMixin):
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]

requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline

# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -369,6 +392,7 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path
@@ -408,6 +432,7 @@ class DiffusionPipeline(ConfigMixin):

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
@@ -428,6 +453,12 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -448,21 +479,39 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]

load_method = getattr(class_obj, load_method_name)
if load_method_name is None:
none_module = class_obj.__module__
if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()

raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)

load_method = getattr(class_obj, load_method_name)
loading_kwargs = {}

if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options

if (
issubclass(class_obj, diffusers.ModelMixin)
or is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
):
loading_kwargs["device_map"] = device_map

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
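A hedged usage sketch of the new keyword: `device_map` is popped from the `from_pretrained` kwargs and forwarded to every sub-model that is a `ModelMixin` or a transformers `PreTrainedModel`. Whether strategies other than `"auto"` are accepted is not shown in this diff, and the call assumes `accelerate` is installed:

```python
from diffusers import StableDiffusionPipeline

# hypothetical invocation for illustration
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    device_map="auto",  # forwarded into each sub-model's loading_kwargs
)
```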
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union

import numpy as np

@@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
(nsfw) content, or `None` if safety checking could not be performed.
"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
nsfw_content_detected: Optional[List[bool]]


if is_transformers_available() and is_torch_available():
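Since the flag is now `Optional`, downstream code that inspects it should allow for the skipped case. A small hypothetical consumer (`pipe` and `prompt` are placeholders):

```python
result = pipe(prompt)
if result.nsfw_content_detected is None:
    print("safety checking was skipped (no safety checker loaded)")
elif any(result.nsfw_content_detected):
    print("at least one image was flagged by the safety checker")
```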
@@ -1,17 +1,27 @@
from functools import partial
from typing import Dict, List, Optional, Union

import numpy as np

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -52,9 +62,18 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.dtype = dtype

if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -78,60 +97,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
)
return text_input.input_ids
def __call__(
def _get_safety_scores(self, features, params):
special_cos_dist, cos_dist = self.safety_checker(features, params)
return (special_cos_dist, cos_dist)

def _run_safety_checker(self, images, safety_model_params, jit=False):
# safety_model_params should already be replicated when jit is True
pil_images = [Image.fromarray(image) for image in images]
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

if jit:
features = shard(features)
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
special_cos_dist = unshard(special_cos_dist)
cos_dist = unshard(cos_dist)
safety_model_params = unreplicate(safety_model_params)
else:
special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)

images, has_nsfw = self.safety_checker.filtered_with_scores(
special_cos_dist,
cos_dist,
images,
safety_model_params,
)
return images, has_nsfw

def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50,
height: Optional[int] = 512,
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.

Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -171,6 +174,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])

latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
@@ -190,6 +195,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

if debug:
# run with python for loop
for i in range(num_inference_steps):
@@ -199,21 +207,119 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image

# image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker
# TODO: check when flax safety checker gets merged into main
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker(
#     images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
# )
has_nsfw_concept = False
def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.

Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)

if self.safety_checker is not None:
safety_params = params["safety_checker"]
images_uint8_casted = (images * 255).round().astype("uint8")
num_devices, batch_size = images.shape[:2]

images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
images = np.asarray(images)

# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
images[i] = np.asarray(images_uint8_casted[i])

images = images.reshape(num_devices, batch_size, height, width, 3)
else:
has_nsfw_concept = False

if not return_dict:
return (image, has_nsfw_concept)
return (images, has_nsfw_concept)

return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)

# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
return pipe._get_safety_scores(features, params)


def unshard(x: jnp.ndarray):
# einops.rearrange(x, 'd b ... -> (d b) ...')
num_devices, batch_size = x.shape[:2]
rest = x.shape[2:]
return x.reshape(num_devices * batch_size, *rest)
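`unshard` is the inverse of `flax.training.common_utils.shard`, which splits the leading batch axis across local devices for `pmap`. A quick round-trip sketch (assuming the total batch divides evenly across devices):

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

prompt_ids = jnp.zeros((jax.device_count() * 2, 77), dtype=jnp.int32)
sharded = shard(prompt_ids)   # -> (num_devices, 2, 77), one slice per device
restored = unshard(sharded)   # -> back to (num_devices * 2, 77)
assert restored.shape == prompt_ids.shape
```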
@@ -71,6 +71,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)

if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -234,8 +244,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
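The only change in this hunk (and the identical ones below) is the `f` prefix; without it the braces are emitted literally rather than evaluated, so the old error message never showed the offending types:

```python
negative_prompt = 42
print("but got {type(negative_prompt)}")   # -> but got {type(negative_prompt)}
print(f"but got {type(negative_prompt)}")  # -> but got <class 'int'>
```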
@@ -331,12 +341,19 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None

if output_type == "pil":
image = self.numpy_to_pil(image)
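Taken together with the constructor warning above, the pipeline now runs end to end without a safety checker instead of crashing on the `None` attribute. A hedged usage sketch (assumes access to the checkpoint; the warning text shown above is logged once at construction):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", safety_checker=None
)
out = pipe("a photo of an astronaut riding a horse")
assert out.nsfw_content_detected is None  # checking was skipped entirely
```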
@@ -83,6 +83,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)

if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -249,8 +259,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
@@ -284,8 +294,25 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents

# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
if isinstance(prompt, str):
prompt = [prompt]
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
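The new duplication rule reduces to simple divisibility arithmetic; a shapes-only illustration of the three branches above:

```python
# 3 prompts, 1 init image: allowed but deprecated -> latents are repeated
len_prompt, n_init, num_images_per_prompt = 3, 1, 2
assert len_prompt % n_init == 0
additional_image_per_prompt = len_prompt // n_init                            # 3
final_batch = n_init * additional_image_per_prompt * num_images_per_prompt   # 6

# 3 prompts, 2 init images: 3 % 2 != 0 -> ValueError, no sensible duplication
# 1 prompt, 1 init image: the plain else-branch, repeated num_images_per_prompt times
```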
@@ -342,10 +369,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
        )
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)
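With the conditional above, a pipeline constructed with `safety_checker=None` skips the checker entirely and reports no NSFW flags. A small usage sketch, assuming the standard v1-4 checkpoint; the constructor warning from the earlier hunk fires once on load:

```python
from diffusers import StableDiffusionPipeline

# Opting out of the safety checker: logger.warn fires at construction time,
# and the output's nsfw_content_detected field is None instead of a list.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", safety_checker=None
)
assert pipe.safety_checker is None

result = pipe("a photograph of an astronaut riding a horse")
print(result.nsfw_content_detected)  # None when the checker is disabled
```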
@@ -98,6 +98,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,

@@ -266,8 +276,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            uncond_tokens = [""]
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                " {type(prompt)}."
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]

@@ -382,8 +392,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)
@@ -108,8 +108,8 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                " {type(prompt)}."
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
@@ -19,6 +19,8 @@ def cosine_distance(image_embeds, text_embeds):
class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

@@ -28,16 +30,17 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.register_buffer("concept_embeds_weights", torch.ones(17))
        self.register_buffer("special_care_embeds_weights", torch.ones(3))
        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
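The added `.float()` calls exist because NumPy has no bfloat16 dtype, so a bf16 embedding cannot be converted directly. A minimal repro of the failure mode the cast avoids:

```python
import torch

x = torch.ones(2, 2, dtype=torch.bfloat16)
# x.cpu().numpy() raises "Got unsupported ScalarType BFloat16";
# upcasting to float32 first, as in the change above, is cheap and safe.
x_np = x.cpu().float().numpy()
print(x_np.dtype)  # float32
```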
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -192,7 +192,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
        deprecated_offset = deprecate(
            "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
            "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
        )
        offset = deprecated_offset or self.config.steps_offset

@@ -283,8 +283,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, generator=generator).to(device)
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

            prev_sample = prev_sample + variance
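Adding `dtype=model_output.dtype` keeps the DDIM variance noise in the sample's precision; `torch.randn` otherwise defaults to float32, and the addition would silently promote an fp16 sample. A sketch of the promotion being avoided (half-precision `randn` support can vary by device and build):

```python
import torch

sample = torch.zeros(2, dtype=torch.float16)
noise = torch.randn(2)  # defaults to float32

# Without an explicit dtype, the addition promotes the fp16 sample to fp32:
assert (sample + noise).dtype == torch.float32

# The fix above draws the noise in the sample's own dtype instead:
noise = torch.randn(2, dtype=sample.dtype)
assert (sample + noise).dtype == torch.float16
```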
@@ -141,6 +141,23 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def create_state(self):
        return DDIMSchedulerState.create(
            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
@@ -116,7 +116,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -133,6 +133,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
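The new `"sigmoid"` branch builds a logistic beta ramp (used by GeoDiff). A standalone sketch of what it computes:

```python
import torch

num_train_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02

# Mirrors the new "sigmoid" branch: a logistic curve over [-6, 6], rescaled
# so the betas climb smoothly from ~beta_start up to ~beta_end.
betas = torch.sigmoid(torch.linspace(-6, 6, num_train_timesteps)) * (beta_end - beta_start) + beta_start

assert betas[0] < betas[-1]
assert abs(betas[0].item() - beta_start) < 1e-3 and abs(betas[-1].item() - beta_end) < 1e-3
```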
@@ -90,7 +90,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -78,7 +78,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -217,7 +217,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "timestep as an index",
            "0.5.0",
            "0.7.0",
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `LMSDiscreteScheduler.step()` will not be supported in future versions. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep.",

@@ -267,7 +267,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
            deprecate(
                "timesteps as indices",
                "0.5.0",
                "0.7.0",
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `LMSDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
                " pass values from `scheduler.timesteps` as timesteps.",
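Both deprecations push callers toward passing actual timestep values rather than loop indices. A minimal sketch of the supported pattern; the zero `model_output` stands in for a real UNet prediction, and a real LMS loop would also scale the model input by `1 / (sigma**2 + 1) ** 0.5`:

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)

sample = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:  # pass timestep values, not enumerate() indices
    model_output = torch.zeros_like(sample)  # stand-in for a UNet prediction
    sample = scheduler.step(model_output, t, sample).prev_sample
```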
@@ -104,7 +104,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -159,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
        deprecated_offset = deprecate(
            "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
            "offset", "0.7.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs
        )
        offset = deprecated_offset or self.config.steps_offset
@@ -153,6 +153,9 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

    def create_state(self):
        return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

@@ -196,7 +199,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
        )

        return state.replace(
            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
            timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int32),
            counter=0,
            # Reserve space for the state variables
            cur_model_output=jnp.zeros(shape),

@@ -204,6 +207,23 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
            ets=jnp.zeros((4,) + shape),
        )

    def scale_model_input(
        self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def step(
        self,
        state: PNDMSchedulerState,
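`scale_model_input` is a no-op for PNDM but was added so every Flax scheduler exposes the same interface; a sampling loop can then call it unconditionally before invoking the model. A small sketch:

```python
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()
state = scheduler.create_state()
sample = jnp.zeros((1, 4, 64, 64))

# The call is an identity for PNDM, but scheduler-agnostic loops no longer
# need to special-case schedulers that do rescale their inputs.
scaled = scheduler.scale_model_input(state, sample, timestep=0)
assert (scaled == sample).all()
```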
@@ -79,7 +79,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    ):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -156,10 +156,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def set_seed(self, seed):
        deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.")
        torch.manual_seed(seed)

    def step_pred(
        self,
        model_output: torch.FloatTensor,

@@ -167,7 +163,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion

@@ -186,9 +181,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"

@@ -231,7 +223,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly

@@ -249,9 +240,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
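With `set_seed` and the `seed` kwarg removed, reproducibility flows through an explicit `torch.Generator` instead of mutating global RNG state. A sketch under assumed shapes, with a zero tensor standing in for a real score-model output:

```python
import torch
from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler()
scheduler.set_timesteps(10)
scheduler.set_sigmas(10)

sample = torch.randn(1, 3, 32, 32)
model_output = torch.zeros_like(sample)  # stand-in for a score model output

# Instead of the removed seed kwarg / set_seed helper, pass a generator:
generator = torch.Generator().manual_seed(0)
out = scheduler.step_pred(model_output, scheduler.timesteps[0], sample, generator=generator)
print(out.prev_sample.shape)
```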
@@ -43,7 +43,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
        deprecate(
            "tensor_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this argument.",
            take_from=kwargs,
        )

@@ -45,7 +45,7 @@ class SchedulerMixin:
    def set_format(self, tensor_format="pt"):
        deprecate(
            "set_format",
            "0.5.0",
            "0.6.0",
            "If you're running your code in PyTorch, you can safely remove this function as the schedulers are always"
            " in Pytorch",
        )

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass

import jax.numpy as jnp

@@ -42,12 +41,3 @@ class FlaxSchedulerMixin:
    """

    config_name = SCHEDULER_CONFIG_NAME

    def set_format(self, tensor_format="pt"):
        warnings.warn(
            "The method `set_format` is deprecated and will be removed in version `0.5.0`."
            "If you're running your code in PyTorch, you can safely remove this function as the schedulers"
            "are always in Pytorch",
            DeprecationWarning,
        )
        return self
@@ -9,3 +9,11 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])

@@ -10,6 +10,14 @@ class FlaxModelMixin(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxUNet2DConditionModel(metaclass=DummyObject):
    _backends = ["flax"]

@@ -17,6 +25,14 @@ class FlaxUNet2DConditionModel(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxAutoencoderKL(metaclass=DummyObject):
    _backends = ["flax"]

@@ -24,6 +40,14 @@ class FlaxAutoencoderKL(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxDiffusionPipeline(metaclass=DummyObject):
    _backends = ["flax"]

@@ -31,6 +55,14 @@ class FlaxDiffusionPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxDDIMScheduler(metaclass=DummyObject):
    _backends = ["flax"]

@@ -38,6 +70,14 @@ class FlaxDDIMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxDDPMScheduler(metaclass=DummyObject):
    _backends = ["flax"]

@@ -45,6 +85,14 @@ class FlaxDDPMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxKarrasVeScheduler(metaclass=DummyObject):
    _backends = ["flax"]

@@ -52,6 +100,14 @@ class FlaxKarrasVeScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
    _backends = ["flax"]

@@ -59,6 +115,14 @@ class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxPNDMScheduler(metaclass=DummyObject):
    _backends = ["flax"]

@@ -66,6 +130,14 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxSchedulerMixin(metaclass=DummyObject):
    _backends = ["flax"]

@@ -73,9 +145,25 @@ class FlaxSchedulerMixin(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
@@ -10,6 +10,14 @@ class ModelMixin(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class AutoencoderKL(metaclass=DummyObject):
    _backends = ["torch"]

@@ -17,6 +25,14 @@ class AutoencoderKL(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class UNet2DConditionModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -24,6 +40,14 @@ class UNet2DConditionModel(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class UNet2DModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -31,6 +55,14 @@ class UNet2DModel(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class VQModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -38,6 +70,14 @@ class VQModel(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def get_constant_schedule(*args, **kwargs):
    requires_backends(get_constant_schedule, ["torch"])

@@ -73,6 +113,14 @@ class DiffusionPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDIMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -80,6 +128,14 @@ class DDIMPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDPMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -87,6 +143,14 @@ class DDPMPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class KarrasVePipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -94,6 +158,14 @@ class KarrasVePipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LDMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -101,6 +173,14 @@ class LDMPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -108,6 +188,14 @@ class PNDMPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ScoreSdeVePipeline(metaclass=DummyObject):
    _backends = ["torch"]

@@ -115,6 +203,14 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDIMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -122,6 +218,14 @@ class DDIMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DDPMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -129,6 +233,14 @@ class DDPMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class KarrasVeScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -136,6 +248,14 @@ class KarrasVeScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PNDMScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -143,6 +263,14 @@ class PNDMScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SchedulerMixin(metaclass=DummyObject):
    _backends = ["torch"]

@@ -150,6 +278,14 @@ class SchedulerMixin(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class ScoreSdeVeScheduler(metaclass=DummyObject):
    _backends = ["torch"]

@@ -157,9 +293,25 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class EMAModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
@@ -9,3 +9,11 @@ class LMSDiscreteScheduler(metaclass=DummyObject):

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "scipy"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])

@@ -9,3 +9,11 @@ class StableDiffusionOnnxPipeline(metaclass=DummyObject):

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

@@ -10,6 +10,14 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -17,6 +25,14 @@ class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -24,9 +40,25 @@ class StableDiffusionInpaintPipeline(metaclass=DummyObject):
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
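The repeated `from_config`/`from_pretrained` classmethods make sure that whichever entry point a user calls on a missing-backend placeholder, they get the same "install X" error rather than an `AttributeError`. A condensed, self-contained sketch of the mechanism; the names mirror the real helpers in `diffusers.utils`, but the bodies here are simplified stand-ins:

```python
# Condensed sketch of the dummy-object mechanism used in the hunks above.
def requires_backends(obj, backends):
    raise ImportError(f"{obj} requires the following backends: {backends}")

class DummyObject(type):
    def __getattr__(cls, name):
        requires_backends(cls, cls._backends)

class FlaxDDIMScheduler(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Explicit classmethods keep the error identical for every entry point.
        requires_backends(cls, ["flax"])

try:
    FlaxDDIMScheduler.from_pretrained("some/checkpoint")
except ImportError as e:
    print(e)  # ... requires the following backends: ['flax']
```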
@@ -7,20 +7,27 @@ from distutils.util import strtobool
from pathlib import Path
from typing import Union

import torch

import PIL.Image
import PIL.ImageOps
import requests
from packaging import version

from .import_utils import is_flax_available, is_torch_available


global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")

if is_torch_higher_equal_than_1_12:
    torch_device = "mps" if torch.backends.mps.is_available() else torch_device

if is_torch_available():
    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse(
        "1.12"
    )

    if is_torch_higher_equal_than_1_12:
        torch_device = "mps" if torch.backends.mps.is_available() else torch_device


def get_tests_dir(append_path=None):

@@ -89,6 +96,13 @@ def slow(test_case):
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


def require_flax(test_case):
    """
    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
    """
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    Args:
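Moving the torch-dependent globals behind `is_torch_available()` lets pure-Flax environments import the testing utilities without PyTorch installed. The guard pattern in isolation:

```python
# The guarded-import pattern adopted above: torch is only touched when it is
# actually installed, so flax-only environments can import this module.
from diffusers.utils import is_torch_available

torch_device = None
if is_torch_available():
    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

print(torch_device)  # None without torch, otherwise "cpu" or "cuda"
```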
tests/test_modeling_common_flax.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax


@require_flax
class FlaxModelTesterMixin:
    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_with_norm_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
        jax.lax.stop_gradient(variables)

        output = model.apply(variables, inputs_dict["sample"])

        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@@ -273,37 +273,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()
        out.sum().backward()

        # now we save the output and parameter gradients that we will use for comparison purposes with
        # the non-checkpointed run.
        output_not_checkpointed = out.data.clone()
        grad_not_checkpointed = {}
        for name, param in model.named_parameters():
            grad_not_checkpointed[name] = param.grad.data.clone()
        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        model.enable_gradient_checkpointing()
        out = model(**inputs_dict).sample
        # re-instantiate the model now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()
        out.sum().backward()

        # now we save the output and parameter gradients that we will use for comparison purposes with
        # the non-checkpointed run.
        output_checkpointed = out.data.clone()
        grad_checkpointed = {}
        for name, param in model.named_parameters():
            grad_checkpointed[name] = param.grad.data.clone()
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
        for name in grad_checkpointed:
            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))


# TODO(Patrick) - Re-add this test after having cleaned up LDM
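The rewritten test clones the model before enabling checkpointing, so both variants see identical weights and the gradient comparison is meaningful. The underlying guarantee, that checkpointing recomputes activations in the backward pass but reproduces the same gradients, in miniature:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Gradient checkpointing trades compute for memory: activations are
# recomputed during backward, but the resulting gradients must match.
layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

layer(x).sum().backward()
grad_plain = layer.weight.grad.clone()

layer.zero_grad()
checkpoint(layer, x).sum().backward()  # same math, checkpointed

assert torch.allclose(layer.weight.grad, grad_plain, atol=1e-6)
```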
tests/test_models_vae_flax.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import unittest

from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax

from .test_modeling_common_flax import FlaxModelTesterMixin


if is_flax_available():
    import jax


@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
    model_class = FlaxAutoencoderKL

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        prng_key = jax.random.PRNGKey(0)
        image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))

        return {"sample": image, "prng_key": prng_key}

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64],
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "latent_channels": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -17,12 +17,15 @@ import gc
import os
import random
import tempfile
import tracemalloc
import unittest

import numpy as np
import torch

import accelerate
import PIL
import transformers
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,

@@ -50,6 +53,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -488,6 +492,23 @@ class PipelineFastTests(unittest.TestCase):
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_from_pretrained_error_message_uninstalled_packages(self):
        # TODO(Patrick, Pedro) - need better test here for the future
        pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-lms-pipe")
        assert isinstance(pipe, StableDiffusionPipeline)
        assert isinstance(pipe.scheduler, LMSDiscreteScheduler)

    def test_stable_diffusion_no_safety_checker(self):
        pipe = StableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
        )
        assert isinstance(pipe, StableDiffusionPipeline)
        assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
        assert pipe.safety_checker is None

        image = pipe("example prompt", num_inference_steps=2).images[0]
        assert image is not None

    def test_stable_diffusion_k_lms(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
@@ -565,6 +586,46 @@ class PipelineFastTests(unittest.TestCase):

        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4

    def test_stable_diffusion_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        negative_prompt = "french fries"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 128, 128, 3)
        expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_score_sde_ve_pipeline(self):
        unet = self.dummy_uncond_unet
        scheduler = ScoreSdeVeScheduler()
@@ -694,6 +755,90 @@ class PipelineFastTests(unittest.TestCase):
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        init_image = self.dummy_image.to(device)

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        negative_prompt = "french fries"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            init_image=init_image,
        )
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_multiple_init_images(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1)

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionImg2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = 2 * ["A painting of a squirrel eating a burger"]
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            init_image=init_image,
        )

        image = output.images

        image_slice = image[-1, -3:, -3:, -1]

        assert image.shape == (2, 32, 32, 3)
        expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_img2img_k_lms(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
@@ -809,6 +954,52 @@ class PipelineFastTests(unittest.TestCase):
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_inpaint_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionInpaintPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=self.dummy_safety_checker,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        negative_prompt = "french fries"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="np",
            init_image=init_image,
            mask_image=mask_image,
        )

        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_num_images_per_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
@@ -1851,13 +2042,21 @@ class PipelineTesterMixin(unittest.TestCase):
                    [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 50:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array(
                    [1.1078, 1.5803, 0.2773, -0.0589, -1.7928, -0.3665, -0.4695, -1.0727, -1.1601]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2

        test_callback_fn.has_been_called = False

        pipe = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
        )
        pipe.to(torch_device)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

@@ -1891,6 +2090,12 @@ class PipelineTesterMixin(unittest.TestCase):
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.9052, -0.0184, 0.4810, 0.2898, 0.5851, 1.4920, 0.5362, 1.9838, 0.0530])
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 37:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 96)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.7071, 0.7831, 0.8300, 1.8140, 1.7840, 1.9402, 1.3651, 1.6590, 1.2828])
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2

        test_callback_fn.has_been_called = False

@@ -1941,6 +2146,12 @@ class PipelineTesterMixin(unittest.TestCase):
                    [-0.5472, 1.1218, -0.5505, -0.9390, -1.0794, 0.4063, 0.5158, 0.6429, -1.5246]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 37:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.4781, 1.1572, 0.6258, 0.2291, 0.2554, -0.1443, 0.7085, -0.1598, -0.5659])
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

        test_callback_fn.has_been_called = False

@@ -1993,6 +2204,13 @@ class PipelineTesterMixin(unittest.TestCase):
                    [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 5:
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array(
                    [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

        test_callback_fn.has_been_called = False

@@ -2007,3 +2225,53 @@ class PipelineTesterMixin(unittest.TestCase):
        pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
        assert test_callback_fn.has_been_called
        assert number_of_steps == 6

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_accelerate_load_works(self):
        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
            return

        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
            return

        model_id = "CompVis/stable-diffusion-v1-4"
        _ = StableDiffusionPipeline.from_pretrained(
            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
        ).to(torch_device)

    @slow
    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
            return

        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
            return

        pipeline_id = "CompVis/stable-diffusion-v1-4"

        torch.cuda.empty_cache()
        gc.collect()

        tracemalloc.start()
        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
        )
        pipeline_normal_load.to(torch_device)
        _, peak_normal = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        del pipeline_normal_load
        torch.cuda.empty_cache()
        gc.collect()

        tracemalloc.start()
        _ = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
        )
        _, peak_accelerate = tracemalloc.get_traced_memory()

        tracemalloc.stop()

        assert peak_accelerate < peak_normal
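Note that `tracemalloc` only observes Python-level (host) allocations, which is exactly what the footprint test compares: with `device_map="auto"`, accelerate avoids materializing a full extra copy of the weights in CPU RAM while loading. The measurement pattern on its own, with a list standing in for weight loading:

```python
import tracemalloc

# tracemalloc tracks host-side Python allocations, the quantity the test
# above compares between the normal and accelerate loading paths.
tracemalloc.start()
buf = [0] * 1_000_000  # stand-in for loading model weights
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"current={current / 1e6:.1f} MB, peak={peak / 1e6:.1f} MB")
```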
201
tests/test_pipelines_flax.py
Normal file
201
tests/test_pipelines_flax.py
Normal file
@@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax, slow


if is_flax_available():
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard
    from jax import pmap


@require_flax
@slow
class FlaxPipelineTests(unittest.TestCase):
    def test_dummy_all_tpus(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
        )

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 4

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

        assert images.shape == (8, 1, 64, 64, 3)
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 5e-1

        images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

        assert len(images_pil) == 8

    def test_stable_diffusion_v1_4(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
        )

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 50

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

        assert images.shape == (8, 1, 512, 512, 3)
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1

    def test_stable_diffusion_v1_4_bfloat_16(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
        )

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 50

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

        assert images.shape == (8, 1, 512, 512, 3)
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1

    def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
        )

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 50

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

        assert images.shape == (8, 1, 512, 512, 3)
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1

    def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
        scheduler = FlaxDDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            set_alpha_to_one=False,
            steps_offset=1,
        )

        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16,
            scheduler=scheduler,
            safety_checker=None,
        )
        scheduler_state = scheduler.create_state()

        params["scheduler"] = scheduler_state

        prompt = (
            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
            " field, close up, split lighting, cinematic"
        )

        prng_seed = jax.random.PRNGKey(0)
        num_inference_steps = 50

        num_samples = jax.device_count()
        prompt = num_samples * [prompt]
        prompt_ids = pipeline.prepare_inputs(prompt)

        p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

        # shard inputs and rng
        params = replicate(params)
        prng_seed = jax.random.split(prng_seed, 8)
        prompt_ids = shard(prompt_ids)

        images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

        assert images.shape == (8, 1, 512, 512, 3)
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
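These tests all follow the same multi-device pattern: replicate the weights, shard the inputs and RNG, then run the call under pmap. A condensed sketch of that pattern, with an illustrative prompt and the RNG split per `jax.device_count()` rather than the hardcoded 8 devices the tests assume:

```python
import jax
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
)

prompt = ["a photo of an astronaut"] * jax.device_count()  # one prompt per device
prompt_ids = pipeline.prepare_inputs(prompt)

# Replicate the weights across devices; shard the per-device inputs and RNG keys.
params = replicate(params)
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
prompt_ids = shard(prompt_ids)

# jit=True pmaps the call internally, like the explicit pmap(...) in the tests above.
images = pipeline(prompt_ids, params, rng, 50, jit=True).images
```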
@@ -38,6 +38,14 @@ class {0}(metaclass=DummyObject):

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, {1})

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, {1})
"""

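For context, roughly what this template expands to once the placeholders {0} (class name) and {1} (required backends) are filled in; the class name below is illustrative, since the real dummy files are auto-generated. Any call into the dummy raises an informative ImportError via `requires_backends` when the backend is missing:

```python
from diffusers.utils import DummyObject, requires_backends


class FlaxStableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        # Raises ImportError with install instructions if flax is unavailable.
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
```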