Mirror of https://github.com/huggingface/diffusers.git

Compare commits: pipeline-i ... test-fixes (2 commits: 2461933857, 11190ed09a)
@@ -244,8 +244,6 @@
- sections:
  - local: api/pipelines/overview
    title: Overview
  - local: api/pipelines/amused
    title: aMUSEd
  - local: api/pipelines/animatediff
    title: AnimateDiff
  - local: api/pipelines/attend_and_excite
@@ -1,30 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# aMUSEd

aMUSEd is a lightweight text-to-image model based on the [MUSE](https://arxiv.org/pdf/2301.00704.pdf) architecture. aMUSEd is particularly useful in applications that require a lightweight and fast model, such as generating many images quickly at once.

aMUSEd is a VQ-VAE token-based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with MUSE, it uses the smaller CLIP-L/14 text encoder instead of T5-XXL. Thanks to its small parameter count and few-forward-pass generation process, aMUSEd can generate many images quickly, particularly at larger batch sizes.

| Model | Params |
|-------|--------|
| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M |
| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M |
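For orientation, here is a minimal text-to-image sketch with `AmusedPipeline`. The checkpoint name follows the table above; the device, dtype, prompt, and step count are illustrative assumptions rather than part of this diff:

```python
import torch
from diffusers import AmusedPipeline

# Checkpoint from the table above; fp16 keeps memory low on a CUDA device.
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# aMUSEd only needs a handful of forward passes per image.
image = pipe(
    "a photo of a dog wearing a party hat",
    num_inference_steps=12,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("amused_256_sample.png")
```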
## AmusedPipeline

[[autodoc]] AmusedPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
@@ -31,14 +31,14 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea

## StableDiffusionLDM3DPipeline

[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
- all
- __call__

## LDM3DPipelineOutput

[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
- all
- __call__
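For orientation, a minimal sketch of calling this pipeline; the `Intel/ldm3d-4c` checkpoint name and the `rgb`/`depth` output fields are assumptions based on the LDM3D release, not part of this diff:

```python
from diffusers import StableDiffusionLDM3DPipeline

# Checkpoint name is an assumption; swap in the LDM3D weights you actually use.
pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to("cuda")

# LDM3D jointly generates an RGB image and a depth map for the same prompt.
output = pipe("a photo of an old library lit by warm lamps")
rgb_image, depth_image = output.rgb[0], output.depth[0]
rgb_image.save("library_rgb.jpg")
depth_image.save("library_depth.png")
```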
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.

# T2I-Adapter

[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because it only inserts weights into the UNet instead of copying and training it.

The T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model.
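As a rough illustration of how an adapter plugs into an SDXL pipeline at inference time, here is a minimal sketch; the checkpoint names and the canny conditioning image path are illustrative assumptions, not part of this diff:

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Adapter and base model names are illustrative; substitute the adapter you trained.
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Placeholder conditioning image (e.g. a canny edge map) that steers the layout of the output.
canny_image = load_image("canny_edge_map.png")

image = pipe(
    prompt="a colorful bird perched on a branch, highly detailed",
    image=canny_image,
    num_inference_steps=30,
).images[0]
```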
@@ -63,42 +63,3 @@ With callbacks, you can implement features such as dynamic CFG without having to

🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use case and require a callback function with a different execution point!

</Tip>

## Using callbacks to interrupt the diffusion process

The following pipelines support interrupting the diffusion process via a callback:

- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](../api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](../api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

The callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (which must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50


def interrupt_callback(pipe, i, t, callback_kwargs):
    # Flag the pipeline to stop denoising once the 10th step has been reached.
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs


pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```
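The same hook can drive other stopping criteria. For example, a sketch of interrupting on a wall-clock budget (the 30-second limit is arbitrary):

```python
import time

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()

start_time = time.monotonic()
max_seconds = 30  # arbitrary generation budget


def timeout_callback(pipe, i, t, callback_kwargs):
    # Interrupt once the budget is exhausted, regardless of the current step index.
    if time.monotonic() - start_time > max_seconds:
        pipe._interrupt = True

    return callback_kwargs


pipe("A photo of a cat", num_inference_steps=50, callback_on_step_end=timeout_callback)
```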
@@ -1,326 +0,0 @@
## aMUSEd training

aMUSEd can be finetuned on simple datasets relatively cheaply and quickly. Using 8-bit optimizers, LoRA, and gradient accumulation, aMUSEd can be finetuned with as little as 5.5 GB of memory. Below is a set of examples for finetuning aMUSEd on some relatively simple datasets. These training recipes are aggressively oriented towards minimal resources and fast verification, i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease the learning rates.

All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8-bit Adam + LoRA because it uses about the same memory as LoRA alone (bitsandbytes uses full-precision optimizer states for weights below a minimum size).
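For reference, a condensed sketch of what the `--use_lora` path in `train_amused.py` (included later in this diff) does to the transformer; the rank, alpha, and target modules mirror the script's defaults:

```python
from diffusers import UVit2DModel
from peft import LoraConfig

# Load the aMUSEd transformer and wrap its attention projections with LoRA adapters.
model = UVit2DModel.from_pretrained("huggingface/amused-256", subfolder="transformer")
lora_config = LoraConfig(
    r=16,  # script default for --lora_r
    lora_alpha=32,  # script default for --lora_alpha
    target_modules=["to_q", "to_k", "to_v"],  # script default for --lora_target_modules
)
model.add_adapter(lora_config)

# Only the small adapter weights remain trainable, which is what keeps memory usage low.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable / 1e6:.1f}M")
```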
### Finetuning the 256 checkpoint

These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.

Example results:

#### Full finetuning

Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 8          | 1                           | 8                          | 19.7 GB     |
| 4          | 2                           | 8                          | 18.3 GB     |
| 1          | 8                           | 8                          | 17.9 GB     |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 1e-4 \
    --pretrained_model_name_or_path huggingface/amused-256 \
    --instance_data_dataset 'm1guelpf/nouns' \
    --image_key image \
    --prompt_key text \
    --resolution 256 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
        'a pixel art character with square red glasses' \
        'a pixel art character' \
        'square red glasses on a pixel art character' \
        'square red glasses on a pixel art character with a baseball-shaped head' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

#### Full finetuning + 8-bit Adam

Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8-bit Adam, it will diverge eventually. If you want to train for longer, you will have to increase the batch size and lower the learning rate.

Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 16         | 1                           | 16                         | 20.1 GB     |
| 8          | 2                           | 16                         | 15.6 GB     |
| 1          | 16                          | 16                         | 10.7 GB     |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 2e-5 \
    --use_8bit_adam \
    --pretrained_model_name_or_path huggingface/amused-256 \
    --instance_data_dataset 'm1guelpf/nouns' \
    --image_key image \
    --prompt_key text \
    --resolution 256 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
        'a pixel art character with square red glasses' \
        'a pixel art character' \
        'square red glasses on a pixel art character' \
        'square red glasses on a pixel art character with a baseball-shaped head' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

#### Full finetuning + LoRA

Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 16         | 1                           | 16                         | 14.1 GB     |
| 8          | 2                           | 16                         | 10.1 GB     |
| 1          | 16                          | 16                         | 6.5 GB      |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 8e-4 \
    --use_lora \
    --pretrained_model_name_or_path huggingface/amused-256 \
    --instance_data_dataset 'm1guelpf/nouns' \
    --image_key image \
    --prompt_key text \
    --resolution 256 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
        'a pixel art character with square red glasses' \
        'a pixel art character' \
        'square red glasses on a pixel art character' \
        'square red glasses on a pixel art character with a baseball-shaped head' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

### Finetuning the 512 checkpoint

These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset.

Example results:

#### Full finetuning

Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 8          | 1                           | 8                          | 24.2 GB     |
| 4          | 2                           | 8                          | 19.7 GB     |
| 1          | 8                           | 8                          | 16.99 GB    |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 8e-5 \
    --pretrained_model_name_or_path huggingface/amused-512 \
    --instance_data_dataset 'monadical-labs/minecraft-preview' \
    --prompt_prefix 'minecraft ' \
    --image_key image \
    --prompt_key text \
    --resolution 512 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'minecraft Avatar' \
        'minecraft character' \
        'minecraft' \
        'minecraft president' \
        'minecraft pig' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

#### Full finetuning + 8-bit Adam

Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 8          | 1                           | 8                          | 21.2 GB     |
| 4          | 2                           | 8                          | 13.3 GB     |
| 1          | 8                           | 8                          | 9.9 GB      |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 5e-6 \
    --use_8bit_adam \
    --pretrained_model_name_or_path huggingface/amused-512 \
    --instance_data_dataset 'monadical-labs/minecraft-preview' \
    --prompt_prefix 'minecraft ' \
    --image_key image \
    --prompt_key text \
    --resolution 512 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'minecraft Avatar' \
        'minecraft character' \
        'minecraft' \
        'minecraft president' \
        'minecraft pig' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

#### Full finetuning + LoRA

Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps

| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|------------|-----------------------------|----------------------------|-------------|
| 8          | 1                           | 8                          | 12.7 GB     |
| 4          | 2                           | 8                          | 9.0 GB      |
| 1          | 8                           | 8                          | 5.6 GB      |

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --train_batch_size <batch size> \
    --gradient_accumulation_steps <gradient accumulation steps> \
    --learning_rate 1e-4 \
    --use_lora \
    --pretrained_model_name_or_path huggingface/amused-512 \
    --instance_data_dataset 'monadical-labs/minecraft-preview' \
    --prompt_prefix 'minecraft ' \
    --image_key image \
    --prompt_key text \
    --resolution 512 \
    --mixed_precision fp16 \
    --lr_scheduler constant \
    --validation_prompts \
        'minecraft Avatar' \
        'minecraft character' \
        'minecraft' \
        'minecraft president' \
        'minecraft pig' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 250 \
    --gradient_checkpointing
```

### StyleDrop

[StyleDrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage that generates additional, human-picked training samples, which can be used to augment the initial images. Our examples skip this optional image selection stage and instead finetune on a single image.

This is our example style image:

![A mushroom in [V] style](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png)

Download it to your local directory with:

```sh
wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png
```

#### 256

Example results:

Learning rate: 4e-4, Gives decent results in 1500-2000 steps

Memory used: 6.5 GB

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --mixed_precision fp16 \
    --report_to wandb \
    --use_lora \
    --pretrained_model_name_or_path huggingface/amused-256 \
    --train_batch_size 1 \
    --lr_scheduler constant \
    --learning_rate 4e-4 \
    --validation_prompts \
        'A chihuahua walking on the street in [V] style' \
        'A banana on the table in [V] style' \
        'A church on the street in [V] style' \
        'A tabby cat walking in the forest in [V] style' \
    --instance_data_image 'A mushroom in [V] style.png' \
    --max_train_steps 10000 \
    --checkpointing_steps 500 \
    --validation_steps 100 \
    --resolution 256
```

#### 512

Example results:

Learning rate: 1e-3, LoRA alpha 1, Gives decent results in 1500-2000 steps

Memory used: 5.6 GB

```sh
accelerate launch train_amused.py \
    --output_dir <output path> \
    --mixed_precision fp16 \
    --report_to wandb \
    --use_lora \
    --pretrained_model_name_or_path huggingface/amused-512 \
    --train_batch_size 1 \
    --lr_scheduler constant \
    --learning_rate 1e-3 \
    --validation_prompts \
        'A chihuahua walking on the street in [V] style' \
        'A banana on the table in [V] style' \
        'A church on the street in [V] style' \
        'A tabby cat walking in the forest in [V] style' \
    --instance_data_image 'A mushroom in [V] style.png' \
    --max_train_steps 100000 \
    --checkpointing_steps 500 \
    --validation_steps 100 \
    --resolution 512 \
    --lora_alpha 1
```
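After a full finetune, the script's final `save_pretrained` call writes only the transformer, so inference loads the base pipeline and swaps that component in. A hedged sketch (`<output path>` is the `--output_dir` used above):

```python
from diffusers import AmusedPipeline, UVit2DModel

# The finetuned transformer written by train_amused.py's final save_pretrained().
transformer = UVit2DModel.from_pretrained("<output path>")

# Remaining components (tokenizer, text encoder, VQ-VAE, scheduler) come from the base checkpoint.
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", transformer=transformer)
pipe.to("cuda")

image = pipe("a pixel art character with square red glasses").images[0]
image.save("finetuned_sample.png")
```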
@@ -1,972 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import logging
import math
import os
import shutil
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import DataLoader, Dataset, default_collate
from torchvision import transforms
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

import diffusers.optimization
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import is_wandb_available


if is_wandb_available():
    import wandb

logger = get_logger(__name__, log_level="INFO")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--instance_data_dataset",
        type=str,
        default=None,
        required=False,
        help="A Hugging Face dataset containing the training images",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--instance_data_image", type=str, default=None, required=False, help="A single training image"
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--ema_update_after_step", type=int, default=0)
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="muse_training",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
            "instructions."
        ),
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more details"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.0003,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--validation_prompts", type=str, nargs="*")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
    parser.add_argument("--min_masking_rate", type=float, default=0.0)
    parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
    parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False)
    parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
    parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
    parser.add_argument("--lora_r", default=16, type=int)
    parser.add_argument("--lora_alpha", default=32, type=int)
    parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
    parser.add_argument("--text_encoder_lora_r", default=16, type=int)
    parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
    parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
    parser.add_argument("--train_text_encoder", action="store_true")
    parser.add_argument("--image_key", type=str, required=False)
    parser.add_argument("--prompt_key", type=str, required=False)
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument("--prompt_prefix", type=str, required=False, default=None)

    args = parser.parse_args()

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    num_datasources = sum(
        [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]
    )

    if num_datasources != 1:
        raise ValueError(
            "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
        )

    if args.instance_data_dir is not None:
        if not os.path.exists(args.instance_data_dir):
            raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")

    if args.instance_data_image is not None:
        if not os.path.exists(args.instance_data_image):
            raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")

    if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):
        raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`")

    return args
class InstanceDataRootDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        tokenizer,
        size=512,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.instance_images_path = list(Path(instance_data_root).iterdir())

    def __len__(self):
        return len(self.instance_images_path)

    def __getitem__(self, index):
        image_path = self.instance_images_path[index % len(self.instance_images_path)]
        instance_image = Image.open(image_path)
        rv = process_image(instance_image, self.size)

        prompt = os.path.splitext(os.path.basename(image_path))[0]
        rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
        return rv


class InstanceDataImageDataset(Dataset):
    def __init__(
        self,
        instance_data_image,
        train_batch_size,
        size=512,
    ):
        self.value = process_image(Image.open(instance_data_image), size)
        self.train_batch_size = train_batch_size

    def __len__(self):
        # Needed so a full batch of the data can be returned. Otherwise will return
        # batches of size 1
        return self.train_batch_size

    def __getitem__(self, index):
        return self.value


class HuggingFaceDataset(Dataset):
    def __init__(
        self,
        hf_dataset,
        tokenizer,
        image_key,
        prompt_key,
        prompt_prefix=None,
        size=512,
    ):
        self.size = size
        self.image_key = image_key
        self.prompt_key = prompt_key
        self.tokenizer = tokenizer
        self.hf_dataset = hf_dataset
        self.prompt_prefix = prompt_prefix

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, index):
        item = self.hf_dataset[index]

        rv = process_image(item[self.image_key], self.size)

        prompt = item[self.prompt_key]

        if self.prompt_prefix is not None:
            prompt = self.prompt_prefix + prompt

        rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]

        return rv


def process_image(image, size):
    image = exif_transpose(image)

    if not image.mode == "RGB":
        image = image.convert("RGB")

    orig_height = image.height
    orig_width = image.width

    image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)

    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, c_top, c_left, size, size)

    image = transforms.ToTensor()(image)

    micro_conds = torch.tensor(
        [orig_width, orig_height, c_top, c_left, 6.0],
    )

    return {"image": image, "micro_conds": micro_conds}


def tokenize_prompt(tokenizer, prompt):
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=77,
        return_tensors="pt",
    ).input_ids


def encode_prompt(text_encoder, input_ids):
    outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True)
    encoder_hidden_states = outputs.hidden_states[-2]
    cond_embeds = outputs[0]
    return encoder_hidden_states, cond_embeds
def main(args):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_main_process:
        accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))

    if args.seed is not None:
        set_seed(args.seed)

    # TODO - will have to fix loading if training text encoder
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
    )
    vq_model = VQModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
    )

    if args.train_text_encoder:
        if args.text_encoder_use_lora:
            lora_config = LoraConfig(
                r=args.text_encoder_lora_r,
                lora_alpha=args.text_encoder_lora_alpha,
                target_modules=args.text_encoder_lora_target_modules,
            )
            text_encoder.add_adapter(lora_config)
        text_encoder.train()
        text_encoder.requires_grad_(True)
    else:
        text_encoder.eval()
        text_encoder.requires_grad_(False)

    vq_model.requires_grad_(False)

    model = UVit2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
    )

    if args.use_lora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=args.lora_target_modules,
        )
        model.add_adapter(lora_config)

    model.train()

    if args.gradient_checkpointing:
        model.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.use_ema:
        ema = EMAModel(
            model.parameters(),
            decay=args.ema_decay,
            update_after_step=args.ema_update_after_step,
            model_cls=UVit2DModel,
            model_config=model.config,
        )

    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            text_encoder_lora_layers_to_save = None

            for model_ in models:
                if isinstance(model_, type(accelerator.unwrap_model(model))):
                    if args.use_lora:
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
                    else:
                        model_.save_pretrained(os.path.join(output_dir, "transformer"))
                elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
                    if args.text_encoder_use_lora:
                        text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
                    else:
                        model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
                else:
                    raise ValueError(f"unexpected save model: {model_.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
                LoraLoaderMixin.save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                    text_encoder_lora_layers=text_encoder_lora_layers_to_save,
                )

            if args.use_ema:
                ema.save_pretrained(os.path.join(output_dir, "ema_model"))

    def load_model_hook(models, input_dir):
        transformer = None
        text_encoder_ = None

        while len(models) > 0:
            model_ = models.pop()

            if isinstance(model_, type(accelerator.unwrap_model(model))):
                if args.use_lora:
                    transformer = model_
                else:
                    load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer"))
                    model_.load_state_dict(load_model.state_dict())
                    del load_model
            elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
                if args.text_encoder_use_lora:
                    text_encoder_ = model_
                else:
                    load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
                    model_.load_state_dict(load_model.state_dict())
                    del load_model
            else:
                raise ValueError(f"unexpected save model: {model_.__class__}")

        if transformer is not None or text_encoder_ is not None:
            lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
            LoraLoaderMixin.load_lora_into_text_encoder(
                lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
            )
            LoraLoaderMixin.load_lora_into_transformer(
                lora_state_dict, network_alphas=network_alphas, transformer=transformer
            )

        if args.use_ema:
            load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel)
            ema.load_state_dict(load_from.state_dict())
            del load_from

    accelerator.register_load_state_pre_hook(load_model_hook)
    accelerator.register_save_state_pre_hook(save_model_hook)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # no decay on bias and layernorm and embedding
    no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.adam_weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    if args.train_text_encoder:
        optimizer_grouped_parameters.append(
            {"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay}
        )

    optimizer = optimizer_cls(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    logger.info("Creating dataloaders and lr_scheduler")

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    if args.instance_data_dir is not None:
        dataset = InstanceDataRootDataset(
            instance_data_root=args.instance_data_dir,
            tokenizer=tokenizer,
            size=args.resolution,
        )
    elif args.instance_data_image is not None:
        dataset = InstanceDataImageDataset(
            instance_data_image=args.instance_data_image,
            train_batch_size=args.train_batch_size,
            size=args.resolution,
        )
    elif args.instance_data_dataset is not None:
        dataset = HuggingFaceDataset(
            hf_dataset=load_dataset(args.instance_data_dataset, split="train"),
            tokenizer=tokenizer,
            image_key=args.image_key,
            prompt_key=args.prompt_key,
            prompt_prefix=args.prompt_prefix,
            size=args.resolution,
        )
    else:
        assert False

    train_dataloader = DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
        collate_fn=default_collate,
    )
    train_dataloader.num_batches = len(train_dataloader)

    lr_scheduler = diffusers.optimization.get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    )

    logger.info("Preparing model, optimizer and dataloaders")

    if args.train_text_encoder:
        model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
            model, optimizer, lr_scheduler, train_dataloader, text_encoder
        )
    else:
        model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
            model, optimizer, lr_scheduler, train_dataloader
        )

    train_dataloader.num_batches = len(train_dataloader)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if not args.train_text_encoder:
        text_encoder.to(device=accelerator.device, dtype=weight_dtype)

    vq_model.to(device=accelerator.device)

    if args.use_ema:
        ema.to(accelerator.device)

    with nullcontext() if args.train_text_encoder else torch.no_grad():
        empty_embeds, empty_clip_embeds = encode_prompt(
            text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True)
        )

        # There is a single image, we can just pre-encode the single prompt
        if args.instance_data_image is not None:
            prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]
            encoder_hidden_states, cond_embeds = encode_prompt(
                text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True)
            )
            encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)
            cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs.
    # Note: We are not doing epoch based training here, but just using this for book keeping and being able to
    # reuse the same training loop with other datasets/loaders.
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Num training steps = {args.max_train_steps}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")

    resume_from_checkpoint = args.resume_from_checkpoint
    if resume_from_checkpoint:
        if resume_from_checkpoint == "latest":
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            if len(dirs) > 0:
                resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
            else:
                resume_from_checkpoint = None

        if resume_from_checkpoint is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
        else:
            accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")

    if resume_from_checkpoint is None:
        global_step = 0
        first_epoch = 0
    else:
        accelerator.load_state(resume_from_checkpoint)
        global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch

    # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
    # reuse the same training loop with other datasets/loaders.
    for epoch in range(first_epoch, num_train_epochs):
        for batch in train_dataloader:
            with torch.no_grad():
                micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
                pixel_values = batch["image"].to(accelerator.device, non_blocking=True)

                batch_size = pixel_values.shape[0]

                split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
                num_splits = math.ceil(batch_size / split_batch_size)
                image_tokens = []
                for i in range(num_splits):
                    start_idx = i * split_batch_size
                    end_idx = min((i + 1) * split_batch_size, batch_size)
                    bs = end_idx - start_idx  # number of images in this split
                    image_tokens.append(
                        vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape(
                            bs, -1
                        )
                    )
                image_tokens = torch.cat(image_tokens, dim=0)

                batch_size, seq_len = image_tokens.shape

                timesteps = torch.rand(batch_size, device=image_tokens.device)
                mask_prob = torch.cos(timesteps * math.pi * 0.5)
                mask_prob = mask_prob.clip(args.min_masking_rate)

                num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
                batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
                mask = batch_randperm < num_token_masked.unsqueeze(-1)

                mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
                input_ids = torch.where(mask, mask_id, image_tokens)
                labels = torch.where(mask, image_tokens, -100)

                if args.cond_dropout_prob > 0.0:
                    assert encoder_hidden_states is not None

                    batch_size = encoder_hidden_states.shape[0]

                    mask = (
                        torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
                        < args.cond_dropout_prob
                    )

                    empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
                    encoder_hidden_states = torch.where(
                        (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
                    )

                    empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
                    cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)

                bs = input_ids.shape[0]
                vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
                resolution = args.resolution // vae_scale_factor
                input_ids = input_ids.reshape(bs, resolution, resolution)

            if "prompt_input_ids" in batch:
                with nullcontext() if args.train_text_encoder else torch.no_grad():
                    encoder_hidden_states, cond_embeds = encode_prompt(
                        text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
                    )

            # Train Step
            with accelerator.accumulate(model):
                codebook_size = accelerator.unwrap_model(model).config.codebook_size

                logits = (
                    model(
                        input_ids=input_ids,
                        encoder_hidden_states=encoder_hidden_states,
                        micro_conds=micro_conds,
                        pooled_text_emb=cond_embeds,
                    )
                    .reshape(bs, codebook_size, -1)
                    .permute(0, 2, 1)
                    .reshape(-1, codebook_size)
                )

                loss = F.cross_entropy(
                    logits,
                    labels.view(-1),
                    ignore_index=-100,
                    reduction="mean",
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()

                accelerator.backward(loss)

                if args.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()

                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema.step(model.parameters())

                if (global_step + 1) % args.logging_steps == 0:
                    logs = {
                        "step_loss": avg_loss.item(),
                        "lr": lr_scheduler.get_last_lr()[0],
                        "avg_masking_rate": avg_masking_rate.item(),
                    }
                    accelerator.log(logs, step=global_step + 1)

                    logger.info(
                        f"Step: {global_step + 1} "
                        f"Loss: {avg_loss.item():0.4f} "
                        f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
                    )

                if (global_step + 1) % args.checkpointing_steps == 0:
                    save_checkpoint(args, accelerator, global_step + 1)

                if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
                    if args.use_ema:
                        ema.store(model.parameters())
                        ema.copy_to(model.parameters())

                    with torch.no_grad():
                        logger.info("Generating images...")

                        model.eval()

                        if args.train_text_encoder:
                            text_encoder.eval()

                        scheduler = AmusedScheduler.from_pretrained(
                            args.pretrained_model_name_or_path,
                            subfolder="scheduler",
                            revision=args.revision,
                            variant=args.variant,
                        )

                        pipe = AmusedPipeline(
                            transformer=accelerator.unwrap_model(model),
                            tokenizer=tokenizer,
                            text_encoder=text_encoder,
                            vqvae=vq_model,
                            scheduler=scheduler,
                        )

                        pil_images = pipe(prompt=args.validation_prompts).images
                        wandb_images = [
                            wandb.Image(image, caption=args.validation_prompts[i])
                            for i, image in enumerate(pil_images)
                        ]

                        wandb.log({"generated_images": wandb_images}, step=global_step + 1)

                        model.train()

                        if args.train_text_encoder:
                            text_encoder.train()

                        if args.use_ema:
                            ema.restore(model.parameters())

                global_step += 1

                # Stop training if max steps is reached
                if global_step >= args.max_train_steps:
                    break
        # End for

    accelerator.wait_for_everyone()

    # Evaluate and save checkpoint at the end of training
    save_checkpoint(args, accelerator, global_step)

    # Save the final trained checkpoint
    if accelerator.is_main_process:
        model = accelerator.unwrap_model(model)
        if args.use_ema:
            ema.copy_to(model.parameters())
        model.save_pretrained(args.output_dir)

    accelerator.end_training()


def save_checkpoint(args, accelerator, global_step):
    output_dir = args.output_dir

    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
    if accelerator.is_main_process and args.checkpoints_total_limit is not None:
        checkpoints = os.listdir(output_dir)
        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
        if len(checkpoints) >= args.checkpoints_total_limit:
            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]

            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
                shutil.rmtree(removing_checkpoint)

    save_path = Path(output_dir) / f"checkpoint-{global_step}"
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")


if __name__ == "__main__":
    main(parse_args())
|
||||
@@ -8,7 +8,6 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
|
||||
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
|
||||
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
|
||||
| One Step U-Net (Dummy) | Example showcasing how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
@@ -62,53 +61,6 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo

## Example usages

### Marigold Depth Estimation

Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).



This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments:

```python
import numpy as np
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

pipe = DiffusionPipeline.from_pretrained(
    "Bingxin/Marigold",
    custom_pipeline="marigold_depth_estimation"
    # torch_dtype=torch.float16,  # (optional) Run with half-precision (16-bit float).
)

pipe.to("cuda")

img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg"
image: Image.Image = load_image(img_path_or_url)

pipeline_output = pipe(
    image,  # Input image.
    # denoising_steps=10,  # (optional) Number of denoising steps of each inference pass. Default: 10.
    # ensemble_size=10,  # (optional) Number of inference passes in the ensemble. Default: 10.
    # processing_res=768,  # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
    # match_input_res=True,  # (optional) Resize depth prediction to match input resolution.
    # batch_size=0,  # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
    # color_map="Spectral",  # (optional) Colormap used to colorize the depth map. Defaults to "Spectral".
    # show_progress_bar=True,  # (optional) If true, will show progress bars of the inference progress.
)

depth: np.ndarray = pipeline_output.depth_np  # Predicted depth map
depth_colored: Image.Image = pipeline_output.depth_colored  # Colorized prediction

# Save as uint16 PNG
depth_uint16 = (depth * 65535.0).astype(np.uint16)
Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16")

# Save colorized depth map
depth_colored.save("./depth_colored.png")
```

### LLM-grounded Diffusion
|
||||
|
||||
LMD and LMD+ greatly improve the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. They improve spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., it uses the SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of the LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
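
Below is a minimal usage sketch for loading LMD+ as a community pipeline. The checkpoint id (`longlian/lmd_plus`), the `custom_pipeline` name (`llm_grounded_diffusion`), and the `phrases`/`boxes`/`gligen_scheduled_sampling_beta` call arguments are assumptions about the community pipeline's interface rather than a verbatim excerpt; consult the pipeline's docstring for the exact signature.

```python
# Minimal sketch (assumed checkpoint id, custom_pipeline name, and argument names) --
# verify against the community pipeline's docstring before use.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "longlian/lmd_plus",                       # assumed LMD+ checkpoint
    custom_pipeline="llm_grounded_diffusion",  # assumed community pipeline name
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# A layout produced by the LLM front-end: one phrase per box, boxes in normalized [x0, y0, x1, y1].
prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
phrases = ["a waterfall", "a modern high speed train"]
boxes = [[0.14, 0.21, 0.43, 0.86], [0.50, 0.44, 0.85, 0.73]]

images = pipe(
    prompt=prompt,
    phrases=phrases,
    boxes=boxes,
    gligen_scheduled_sampling_beta=0.4,  # assumed: LMD+ reuses GLIGEN-style scheduled sampling
    num_inference_steps=50,
).images
images[0].save("lmd_plus_output.png")
```

In the full LMD+ workflow, the phrase/box layout above would come from the LLM front-end rather than being hard-coded.
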
@@ -1,602 +0,0 @@
|
||||
# Copyright 2023 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
||||
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
||||
# More information about the method can be found at https://marigoldmonodepth.github.io
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from scipy.optimize import minimize
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.1.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Marigold monocular depth prediction pipeline.
|
||||
|
||||
Args:
|
||||
depth_np (`np.ndarray`):
|
||||
Predicted depth map, with depth values in the range of [0, 1].
|
||||
depth_colored (`PIL.Image.Image`):
|
||||
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
||||
uncertainty (`None` or `np.ndarray`):
|
||||
Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
"""
|
||||
|
||||
depth_np: np.ndarray
|
||||
depth_colored: Image.Image
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
class MarigoldPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
unet (`UNet2DConditionModel`):
|
||||
Conditional U-Net to denoise the depth latent, conditioned on image latent.
|
||||
vae (`AutoencoderKL`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
|
||||
to and from latent representations.
|
||||
scheduler (`DDIMScheduler`):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
Text-encoder, for empty text embedding.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
CLIP tokenizer.
|
||||
"""
|
||||
|
||||
rgb_latent_scale_factor = 0.18215
|
||||
depth_latent_scale_factor = 0.18215
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: DDIMScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
self.empty_text_embed = None
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
input_image: Image,
|
||||
denoising_steps: int = 10,
|
||||
ensemble_size: int = 10,
|
||||
processing_res: int = 768,
|
||||
match_input_res: bool = True,
|
||||
batch_size: int = 0,
|
||||
color_map: str = "Spectral",
|
||||
show_progress_bar: bool = True,
|
||||
ensemble_kwargs: Dict = None,
|
||||
) -> MarigoldDepthOutput:
|
||||
"""
|
||||
Function invoked when calling the pipeline.
|
||||
|
||||
Args:
|
||||
input_image (`Image`):
|
||||
Input RGB (or gray-scale) image.
|
||||
processing_res (`int`, *optional*, defaults to `768`):
|
||||
Maximum resolution of processing.
|
||||
If set to 0: will not resize at all.
|
||||
match_input_res (`bool`, *optional*, defaults to `True`):
|
||||
Resize depth prediction to match input resolution.
|
||||
Only valid if `processing_res` is greater than 0.
denoising_steps (`int`, *optional*, defaults to `10`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
|
||||
ensemble_size (`int`, *optional*, defaults to `10`):
|
||||
Number of predictions to be ensembled.
|
||||
batch_size (`int`, *optional*, defaults to `0`):
|
||||
Inference batch size, no bigger than `num_ensemble`.
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`):
|
||||
Colormap used to colorize the depth map.
|
||||
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Arguments for detailed ensembling settings.
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
|
||||
device = self.device
|
||||
input_size = input_image.size
|
||||
|
||||
if not match_input_res:
|
||||
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
|
||||
assert processing_res >= 0
|
||||
assert denoising_steps >= 1
|
||||
assert ensemble_size >= 1
|
||||
|
||||
# ----------------- Image Preprocess -----------------
|
||||
# Resize image
|
||||
if processing_res > 0:
|
||||
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
|
||||
# Convert the image to RGB: (1) remove the alpha channel, (2) convert grayscale to 3-channel
input_image = input_image.convert("RGB")
|
||||
image = np.asarray(input_image)
|
||||
|
||||
# Normalize rgb values
|
||||
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
||||
rgb_norm = rgb / 255.0
|
||||
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
||||
rgb_norm = rgb_norm.to(device)
|
||||
assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0
|
||||
|
||||
# ----------------- Predicting depth -----------------
|
||||
# Batch repeated input image
|
||||
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
||||
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
||||
if batch_size > 0:
|
||||
_bs = batch_size
|
||||
else:
|
||||
_bs = self._find_batch_size(
|
||||
ensemble_size=ensemble_size,
|
||||
input_res=max(rgb_norm.shape[1:]),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
|
||||
|
||||
# Predict depth maps (batched)
|
||||
depth_pred_ls = []
|
||||
if show_progress_bar:
|
||||
iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
|
||||
else:
|
||||
iterable = single_rgb_loader
|
||||
for batch in iterable:
|
||||
(batched_img,) = batch
|
||||
depth_pred_raw = self.single_infer(
|
||||
rgb_in=batched_img,
|
||||
num_inference_steps=denoising_steps,
|
||||
show_pbar=show_progress_bar,
|
||||
)
|
||||
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
||||
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
|
||||
torch.cuda.empty_cache() # clear vram cache for ensembling
|
||||
|
||||
# ----------------- Test-time ensembling -----------------
|
||||
if ensemble_size > 1:
|
||||
depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))
|
||||
else:
|
||||
depth_pred = depth_preds
|
||||
pred_uncert = None
|
||||
|
||||
# ----------------- Post processing -----------------
|
||||
# Scale prediction to [0, 1]
|
||||
min_d = torch.min(depth_pred)
|
||||
max_d = torch.max(depth_pred)
|
||||
depth_pred = (depth_pred - min_d) / (max_d - min_d)
|
||||
|
||||
# Convert to numpy
|
||||
depth_pred = depth_pred.cpu().numpy().astype(np.float32)
|
||||
|
||||
# Resize back to original resolution
|
||||
if match_input_res:
|
||||
pred_img = Image.fromarray(depth_pred)
|
||||
pred_img = pred_img.resize(input_size)
|
||||
depth_pred = np.asarray(pred_img)
|
||||
|
||||
# Clip output range
|
||||
depth_pred = depth_pred.clip(0, 1)
|
||||
|
||||
# Colorize
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
uncertainty=pred_uncert,
|
||||
)
|
||||
|
||||
def _encode_empty_text(self):
|
||||
"""
|
||||
Encode text embedding for empty prompt.
|
||||
"""
|
||||
prompt = ""
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="do_not_pad",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
||||
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
|
||||
"""
|
||||
Perform an individual depth prediction without ensembling.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image.
|
||||
num_inference_steps (`int`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
show_pbar (`bool`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
Returns:
|
||||
`torch.Tensor`: Predicted depth map.
|
||||
"""
|
||||
device = rgb_in.device
|
||||
|
||||
# Set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps # [T]
|
||||
|
||||
# Encode image
|
||||
rgb_latent = self._encode_rgb(rgb_in)
|
||||
|
||||
# Initial depth map (noise)
|
||||
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
|
||||
|
||||
# Batched empty text embedding
|
||||
if self.empty_text_embed is None:
|
||||
self._encode_empty_text()
|
||||
batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024]
|
||||
|
||||
# Denoising loop
|
||||
if show_pbar:
|
||||
iterable = tqdm(
|
||||
enumerate(timesteps),
|
||||
total=len(timesteps),
|
||||
leave=False,
|
||||
desc=" " * 4 + "Diffusion denoising",
|
||||
)
|
||||
else:
|
||||
iterable = enumerate(timesteps)
|
||||
|
||||
for i, t in iterable:
|
||||
unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
|
||||
torch.cuda.empty_cache()
|
||||
depth = self._decode_depth(depth_latent)
|
||||
|
||||
# clip prediction
|
||||
depth = torch.clip(depth, -1.0, 1.0)
|
||||
# shift to [0, 1]
|
||||
depth = (depth + 1.0) / 2.0
|
||||
|
||||
return depth
|
||||
|
||||
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode RGB image into latent.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image to be encoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Image latent.
|
||||
"""
|
||||
# encode
|
||||
h = self.vae.encoder(rgb_in)
|
||||
moments = self.vae.quant_conv(h)
|
||||
mean, logvar = torch.chunk(moments, 2, dim=1)
|
||||
# scale latent
|
||||
rgb_latent = mean * self.rgb_latent_scale_factor
|
||||
return rgb_latent
|
||||
|
||||
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode depth latent into depth map.
|
||||
|
||||
Args:
|
||||
depth_latent (`torch.Tensor`):
|
||||
Depth latent to be decoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Decoded depth map.
|
||||
"""
|
||||
# scale latent
|
||||
depth_latent = depth_latent / self.depth_latent_scale_factor
|
||||
# decode
|
||||
z = self.vae.post_quant_conv(depth_latent)
|
||||
stacked = self.vae.decoder(z)
|
||||
# mean of output channels
|
||||
depth_mean = stacked.mean(dim=1, keepdim=True)
|
||||
return depth_mean
|
||||
|
||||
@staticmethod
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
||||
"""
|
||||
Resize image to limit maximum edge length while keeping aspect ratio.
|
||||
|
||||
Args:
|
||||
img (`Image.Image`):
|
||||
Image to be resized.
|
||||
max_edge_resolution (`int`):
|
||||
Maximum edge length (pixel).
|
||||
|
||||
Returns:
|
||||
`Image.Image`: Resized image.
|
||||
"""
|
||||
original_width, original_height = img.size
|
||||
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
|
||||
|
||||
new_width = int(original_width * downscale_factor)
|
||||
new_height = int(original_height * downscale_factor)
|
||||
|
||||
resized_img = img.resize((new_width, new_height))
|
||||
return resized_img
|
||||
|
||||
@staticmethod
|
||||
def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None):
|
||||
"""
|
||||
Colorize depth maps.
|
||||
"""
|
||||
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
depth = depth_map.detach().clone().squeeze().numpy()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
depth = depth_map.copy().squeeze()
|
||||
# reshape to [ (B,) H, W ]
|
||||
if depth.ndim < 3:
|
||||
depth = depth[np.newaxis, :, :]
|
||||
|
||||
# colorize
|
||||
cm = matplotlib.colormaps[cmap]
|
||||
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
||||
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
||||
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
||||
|
||||
if valid_mask is not None:
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
valid_mask = valid_mask.detach().numpy()
|
||||
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
||||
if valid_mask.ndim < 3:
|
||||
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
||||
else:
|
||||
valid_mask = valid_mask[:, np.newaxis, :, :]
|
||||
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
||||
img_colored_np[~valid_mask] = 0
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
img_colored = torch.from_numpy(img_colored_np).float()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
img_colored = img_colored_np
|
||||
|
||||
return img_colored
|
||||
|
||||
@staticmethod
|
||||
def chw2hwc(chw):
|
||||
assert 3 == len(chw.shape)
|
||||
if isinstance(chw, torch.Tensor):
|
||||
hwc = torch.permute(chw, (1, 2, 0))
|
||||
elif isinstance(chw, np.ndarray):
|
||||
hwc = np.moveaxis(chw, 0, -1)
|
||||
return hwc
|
||||
|
||||
@staticmethod
|
||||
def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
||||
"""
|
||||
Automatically search for suitable operating batch size.
|
||||
|
||||
Args:
|
||||
ensemble_size (`int`):
|
||||
Number of predictions to be ensembled.
|
||||
input_res (`int`):
|
||||
Operating resolution of the input image.
|
||||
|
||||
Returns:
|
||||
`int`: Operating batch size.
|
||||
"""
|
||||
# Search table for suggested max. inference batch size
|
||||
bs_search_table = [
|
||||
# tested on A100-PCIE-80GB
|
||||
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
||||
# tested on A100-PCIE-40GB
|
||||
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
||||
# tested on RTX3090, RTX4090
|
||||
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
||||
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
||||
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
||||
# tested on GTX1080Ti
|
||||
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
||||
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
||||
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
||||
]
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return 1
|
||||
|
||||
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
||||
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
||||
for settings in sorted(
|
||||
filtered_bs_search_table,
|
||||
key=lambda k: (k["res"], -k["total_vram"]),
|
||||
):
|
||||
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
||||
bs = settings["bs"]
|
||||
if bs > ensemble_size:
|
||||
bs = ensemble_size
|
||||
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
||||
bs = math.ceil(ensemble_size / 2)
|
||||
return bs
|
||||
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def ensemble_depths(
|
||||
input_images: torch.Tensor,
|
||||
regularizer_strength: float = 0.02,
|
||||
max_iter: int = 2,
|
||||
tol: float = 1e-3,
|
||||
reduction: str = "median",
|
||||
max_res: int = None,
|
||||
):
|
||||
"""
|
||||
Ensemble multiple affine-invariant depth images (each defined only up to an unknown scale and shift)
by jointly estimating the per-image scale and shift that best aligns them.
"""
|
||||
|
||||
def inter_distances(tensors: torch.Tensor):
|
||||
"""
|
||||
Calculate the pairwise differences between all depth maps.
"""
|
||||
distances = []
|
||||
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
||||
arr1 = tensors[i : i + 1]
|
||||
arr2 = tensors[j : j + 1]
|
||||
distances.append(arr1 - arr2)
|
||||
dist = torch.concatenate(distances, dim=0)
|
||||
return dist
|
||||
|
||||
device = input_images.device
|
||||
dtype = input_images.dtype
|
||||
np_dtype = np.float32
|
||||
|
||||
original_input = input_images.clone()
|
||||
n_img = input_images.shape[0]
|
||||
ori_shape = input_images.shape
|
||||
|
||||
if max_res is not None:
|
||||
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
|
||||
if scale_factor < 1:
|
||||
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
|
||||
input_images = downscaler(torch.from_numpy(input_images)).numpy()
|
||||
|
||||
# init guess
|
||||
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
|
||||
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
|
||||
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
|
||||
|
||||
input_images = input_images.to(device)
|
||||
|
||||
# objective function
|
||||
def closure(x):
|
||||
l = len(x)
|
||||
s = x[: int(l / 2)]
|
||||
t = x[int(l / 2) :]
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
|
||||
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
|
||||
dists = inter_distances(transformed_arrays)
|
||||
sqrt_dist = torch.sqrt(torch.mean(dists**2))
|
||||
|
||||
if "mean" == reduction:
|
||||
pred = torch.mean(transformed_arrays, dim=0)
|
||||
elif "median" == reduction:
|
||||
pred = torch.median(transformed_arrays, dim=0).values
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
|
||||
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
|
||||
|
||||
err = sqrt_dist + (near_err + far_err) * regularizer_strength
|
||||
err = err.detach().cpu().numpy().astype(np_dtype)
|
||||
return err
|
||||
|
||||
res = minimize(
|
||||
closure,
|
||||
x,
|
||||
method="BFGS",
|
||||
tol=tol,
|
||||
options={"maxiter": max_iter, "disp": False},
|
||||
)
|
||||
x = res.x
|
||||
l = len(x)
|
||||
s = x[: int(l / 2)]
|
||||
t = x[int(l / 2) :]
|
||||
|
||||
# Prediction
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
|
||||
if "mean" == reduction:
|
||||
aligned_images = torch.mean(transformed_arrays, dim=0)
|
||||
std = torch.std(transformed_arrays, dim=0)
|
||||
uncertainty = std
|
||||
elif "median" == reduction:
|
||||
aligned_images = torch.median(transformed_arrays, dim=0).values
|
||||
# MAD (median absolute deviation) as uncertainty indicator
|
||||
abs_dev = torch.abs(transformed_arrays - aligned_images)
|
||||
mad = torch.median(abs_dev, dim=0).values
|
||||
uncertainty = mad
|
||||
else:
|
||||
raise ValueError(f"Unknown reduction method: {reduction}")
|
||||
|
||||
# Scale and shift to [0, 1]
|
||||
_min = torch.min(aligned_images)
|
||||
_max = torch.max(aligned_images)
|
||||
aligned_images = (aligned_images - _min) / (_max - _min)
|
||||
uncertainty /= _max - _min
|
||||
|
||||
return aligned_images, uncertainty
|
||||
@@ -1004,7 +1004,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
"""
|
||||
self.generator = generator
|
||||
self.denoising_steps = num_inference_steps
|
||||
self._guidance_scale = guidance_scale
|
||||
self.guidance_scale = guidance_scale
|
||||
|
||||
# Pre-compute latent input scales and linear multistep coefficients
|
||||
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
|
||||
|
||||
@@ -94,7 +94,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
--lora_rank=64 \
|
||||
--learning_rate=1e-4 --loss_type="huber" --adam_weight_decay=0.0 \
|
||||
--learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
|
||||
@@ -96,7 +96,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=1024 \
|
||||
--lora_rank=64 \
|
||||
--learning_rate=1e-4 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
|
||||
--learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
|
||||
@@ -65,7 +65,7 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=6
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
@@ -73,7 +73,7 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -85,15 +85,18 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=8
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-6
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
|
||||
|
||||
class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
@@ -108,7 +111,7 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
|
||||
--max_train_steps=4
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
|
||||
@@ -76,7 +76,10 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -90,7 +93,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=4
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
@@ -99,7 +102,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -112,13 +115,16 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=8
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -100,7 +100,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 5
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -114,7 +114,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
@@ -123,7 +123,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
@@ -138,7 +138,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -153,7 +153,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
@@ -196,7 +196,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
@@ -204,7 +204,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -216,12 +216,15 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
@@ -135,13 +135,16 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
@@ -152,15 +155,18 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_if_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -322,7 +328,7 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--learning_rate 5.0e-04
|
||||
@@ -336,11 +342,14 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=1)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
@@ -827,7 +827,6 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
@@ -836,10 +835,7 @@ def main(args):
|
||||
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
)
|
||||
text_encoder.add_adapter(text_lora_config)
|
||||
|
||||
|
||||
@@ -978,10 +978,7 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
@@ -989,10 +986,7 @@ def main(args):
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
@@ -1150,26 +1144,10 @@ def main(args):
|
||||
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warn(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warn(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
)
|
||||
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
|
||||
# --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
params_to_optimize[2]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
decouple=args.prodigy_decouple,
|
||||
|
||||
@@ -40,7 +40,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=6
|
||||
--max_train_steps=7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--output_dir {tmpdir}
|
||||
@@ -63,7 +63,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=4
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
@@ -74,7 +74,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -84,12 +84,12 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=8
|
||||
--max_train_steps=11
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
@@ -97,5 +97,5 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -76,7 +76,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 5
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -89,7 +89,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
@@ -100,12 +100,12 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 2 total steps resuming from checkpoint 4
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -116,13 +116,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
@@ -131,13 +131,16 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-5"},
|
||||
{
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
"checkpoint-4",
|
||||
"checkpoint-6",
|
||||
},
|
||||
)
|
||||
|
||||
def test_text_to_image_checkpointing_use_ema(self):
|
||||
@@ -146,7 +149,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -158,7 +161,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 5
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -183,12 +186,12 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 2 total steps resuming from checkpoint 4
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -199,13 +202,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--use_ema
|
||||
--seed=0
|
||||
@@ -215,13 +218,16 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-5"},
|
||||
{
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
"checkpoint-4",
|
||||
"checkpoint-6",
|
||||
},
|
||||
)
|
||||
|
||||
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
|
||||
@@ -230,7 +236,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -243,7 +249,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -257,11 +263,14 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -269,8 +278,8 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
# max_train_steps == 9, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4, 6, 8
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -281,7 +290,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 9
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -294,15 +303,15 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
# resume and we should try to checkpoint at 6, where we'll have to remove
|
||||
# resume and we should try to checkpoint at 10, where we'll have to remove
|
||||
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -314,27 +323,27 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 8
|
||||
--max_train_steps 11
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -52,7 +52,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -66,11 +66,14 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -78,7 +81,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -91,7 +94,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -109,11 +112,14 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -121,8 +127,8 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
# max_train_steps == 9, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4, 6, 8
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora.py
|
||||
@@ -133,7 +139,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 9
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -150,15 +156,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
# resume and we should try to checkpoint at 6, where we'll have to remove
|
||||
# resume and we should try to checkpoint at 10, where we'll have to remove
|
||||
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -170,15 +176,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 8
|
||||
--max_train_steps 11
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--seed=0
|
||||
--num_validation_images=0
|
||||
""".split()
|
||||
@@ -189,12 +195,12 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
|
||||
|
||||
@@ -266,7 +272,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -277,7 +283,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -292,8 +298,11 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
@@ -452,10 +452,7 @@ def main():
|
||||
param.requires_grad_(False)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
@@ -847,11 +844,10 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
@@ -917,11 +913,8 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if len(images) != 0:
|
||||
|
||||
@@ -609,10 +609,7 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
unet.add_adapter(unet_lora_config)
|
||||
@@ -621,10 +618,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
@@ -40,6 +40,8 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
@@ -66,6 +68,8 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
@@ -98,12 +102,14 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--max_train_steps 3
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -117,7 +123,7 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-1", "checkpoint-2"},
|
||||
{"checkpoint-1", "checkpoint-2", "checkpoint-3"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -127,19 +133,21 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint=checkpoint-2
|
||||
--resume_from_checkpoint=checkpoint-3
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
@@ -148,5 +156,5 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-3"},
|
||||
{"checkpoint-3", "checkpoint-4"},
|
||||
)
|
||||
|
||||
@@ -90,10 +90,10 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
--train_batch_size 1
|
||||
--num_epochs 1
|
||||
--gradient_accumulation_steps 1
|
||||
--ddpm_num_inference_steps 1
|
||||
--ddpm_num_inference_steps 2
|
||||
--learning_rate 1e-3
|
||||
--lr_warmup_steps 5
|
||||
--checkpointing_steps=2
|
||||
--checkpointing_steps=1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
@@ -101,7 +101,7 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
|
||||
{"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -113,12 +113,12 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
--train_batch_size 1
|
||||
--num_epochs 2
|
||||
--gradient_accumulation_steps 1
|
||||
--ddpm_num_inference_steps 1
|
||||
--ddpm_num_inference_steps 2
|
||||
--learning_rate 1e-3
|
||||
--lr_warmup_steps 5
|
||||
--resume_from_checkpoint=checkpoint-6
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--checkpoints_total_limit=3
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
@@ -126,5 +126,5 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-10", "checkpoint-12"},
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
|
||||
@@ -1,523 +0,0 @@
|
||||
import inspect
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from muse import MaskGiTUViT, VQGANModel
|
||||
from muse import PipelineMuse as OldPipelineMuse
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import VQModel
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.uvit_2d import UVit2DModel
|
||||
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
||||
from diffusers.schedulers import AmusedScheduler
|
||||
|
||||
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
# Enable CUDNN deterministic mode
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def main():
|
||||
args = ArgumentParser()
|
||||
args.add_argument("--model_256", action="store_true")
|
||||
args.add_argument("--write_to", type=str, required=False, default=None)
|
||||
args.add_argument("--transformer_path", type=str, required=False, default=None)
|
||||
args = args.parse_args()
|
||||
|
||||
transformer_path = args.transformer_path
|
||||
subfolder = "transformer"
|
||||
|
||||
if transformer_path is None:
|
||||
if args.model_256:
|
||||
transformer_path = "openMUSE/muse-256"
|
||||
else:
|
||||
transformer_path = (
|
||||
"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
|
||||
)
|
||||
subfolder = None
|
||||
|
||||
old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
|
||||
|
||||
old_transformer.to(device)
|
||||
|
||||
old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
|
||||
old_vae.to(device)
|
||||
|
||||
vqvae = make_vqvae(old_vae)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
||||
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
||||
text_encoder.to(device)
|
||||
|
||||
transformer = make_transformer(old_transformer, args.model_256)
|
||||
|
||||
scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
|
||||
|
||||
new_pipe = AmusedPipeline(
|
||||
vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
old_pipe = OldPipelineMuse(
|
||||
vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
|
||||
)
|
||||
old_pipe.to(device)
|
||||
|
||||
if args.model_256:
|
||||
transformer_seq_len = 256
|
||||
orig_size = (256, 256)
|
||||
else:
|
||||
transformer_seq_len = 1024
|
||||
orig_size = (512, 512)
|
||||
|
||||
old_out = old_pipe(
|
||||
"dog",
|
||||
generator=torch.Generator(device).manual_seed(0),
|
||||
transformer_seq_len=transformer_seq_len,
|
||||
orig_size=orig_size,
|
||||
timesteps=12,
|
||||
)[0]
|
||||
|
||||
new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
|
||||
|
||||
old_out = np.array(old_out)
new_out = np.array(new_out)

diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))

# assert diff.sum() == 0
print("skipping pipeline full equivalence check")

print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")

if args.model_256:
assert diff.max() <= 3
assert diff.sum() / diff.size < 0.7
else:
assert diff.max() <= 1
assert diff.sum() / diff.size < 0.4

if args.write_to is not None:
new_pipe.save_pretrained(args.write_to)
|
||||
|
||||
|
||||
def make_transformer(old_transformer, model_256):
|
||||
args = dict(old_transformer.config)
|
||||
force_down_up_sample = args["force_down_up_sample"]
|
||||
|
||||
signature = inspect.signature(UVit2DModel.__init__)
|
||||
|
||||
args_ = {
|
||||
"downsample": force_down_up_sample,
|
||||
"upsample": force_down_up_sample,
|
||||
"block_out_channels": args["block_out_channels"][0],
|
||||
"sample_size": 16 if model_256 else 32,
|
||||
}
|
||||
|
||||
for s in list(signature.parameters.keys()):
|
||||
if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
|
||||
continue
|
||||
|
||||
args_[s] = args[s]
|
||||
|
||||
new_transformer = UVit2DModel(**args_)
|
||||
new_transformer.to(device)
|
||||
|
||||
new_transformer.set_attn_processor(AttnProcessor())
|
||||
|
||||
state_dict = old_transformer.state_dict()
|
||||
|
||||
state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
|
||||
state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
|
||||
|
||||
for i in range(22):
|
||||
state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
|
||||
wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
|
||||
proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
|
||||
state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
|
||||
|
||||
state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
|
||||
|
||||
if force_down_up_sample:
|
||||
state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
|
||||
state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
|
||||
|
||||
state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
|
||||
state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
|
||||
|
||||
state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
|
||||
|
||||
for i in range(3):
|
||||
state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
if key.startswith("up_blocks.0"):
|
||||
key_ = "up_block." + ".".join(key.split(".")[2:])
|
||||
state_dict[key_] = state_dict.pop(key)
|
||||
|
||||
if key.startswith("down_blocks.0"):
|
||||
key_ = "down_block." + ".".join(key.split(".")[2:])
|
||||
state_dict[key_] = state_dict.pop(key)
|
||||
|
||||
new_transformer.load_state_dict(state_dict)
|
||||
|
||||
input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
|
||||
encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
|
||||
cond_embeds = torch.randn((1, 768), device=old_transformer.device)
|
||||
micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
|
||||
|
||||
old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
|
||||
old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
|
||||
|
||||
new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
|
||||
|
||||
# NOTE: these differences are solely due to using the geglu block that has a single linear layer of
|
||||
# double output dimension instead of two different linear layers
|
||||
max_diff = (old_out - new_out).abs().max()
|
||||
total_diff = (old_out - new_out).abs().sum()
|
||||
print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
|
||||
assert max_diff < 0.01
|
||||
assert total_diff < 1500
|
||||
|
||||
return new_transformer
|
||||
|
||||
|
||||
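# Illustrative aside, not part of the original conversion script: the fused projection built
# above with torch.concat([wi_1, wi_0], dim=0) reproduces the two separate muse linears once
# its output is chunked, up to small numeric differences - which is what the "single linear
# layer of double output dimension" note in make_transformer refers to. Sizes are made up;
# torch is already imported at the top of this file.
def _check_geglu_weight_fusion():
    x = torch.randn(1, 4, 16)
    wi_0 = torch.randn(32, 16)  # stand-in for ffn.wi_0.weight
    wi_1 = torch.randn(32, 16)  # stand-in for ffn.wi_1.weight
    fused = torch.concat([wi_1, wi_0], dim=0)  # same ordering as the conversion above
    first_half, second_half = (x @ fused.T).chunk(2, dim=-1)
    assert torch.allclose(first_half, x @ wi_1.T)
    assert torch.allclose(second_half, x @ wi_0.T)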
def make_vqvae(old_vae):
|
||||
new_vae = VQModel(
|
||||
act_fn="silu",
|
||||
block_out_channels=[128, 256, 256, 512, 768],
|
||||
down_block_types=[
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
in_channels=3,
|
||||
latent_channels=64,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
num_vq_embeddings=8192,
|
||||
out_channels=3,
|
||||
sample_size=32,
|
||||
up_block_types=[
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
mid_block_add_attention=False,
|
||||
lookup_from_codebook=True,
|
||||
)
|
||||
new_vae.to(device)
|
||||
|
||||
# fmt: off
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
old_state_dict = old_vae.state_dict()
|
||||
|
||||
new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
|
||||
new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
|
||||
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
|
||||
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
|
||||
new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
|
||||
new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
|
||||
new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
|
||||
new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
|
||||
new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
|
||||
new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
|
||||
new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
|
||||
new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
|
||||
new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
|
||||
new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
|
||||
new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
|
||||
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
|
||||
|
||||
new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
|
||||
new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
|
||||
new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
|
||||
new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
|
||||
|
||||
# fmt: on
|
||||
|
||||
assert len(old_state_dict.keys()) == 0
|
||||
|
||||
new_vae.load_state_dict(new_state_dict)
|
||||
|
||||
input = torch.randn((1, 3, 512, 512), device=device)
|
||||
input = input.clamp(-1, 1)
|
||||
|
||||
old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
|
||||
new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
|
||||
assert (old_encoder_output == new_encoder_output).all()
|
||||
|
||||
old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))

# assert (old_decoder_output == new_decoder_output).all()
print("skipping vae decoder equivalence check")
print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")

old_output = old_vae(input)[0]
new_output = new_vae(input)[0]

# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")

return new_vae
|
||||
|
||||
|
||||
def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
|
||||
# fmt: off
|
||||
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
|
||||
|
||||
if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
|
||||
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
|
||||
|
||||
if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
|
||||
new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
|
||||
|
||||
if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
|
||||
new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
|
||||
|
||||
if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
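# Illustrative aside, not part of the original script: the conversion helpers above move every
# weight with dict.pop, so the `assert len(old_state_dict.keys()) == 0` check in make_vqvae
# catches any key that was never converted. The pattern in miniature (the key names mirror one
# mapping used above, the value is fake):
def _pop_conversion_pattern():
    old_sd = {"decoder.norm_out.weight": 0.0}
    new_sd = {"decoder.conv_norm_out.weight": old_sd.pop("decoder.norm_out.weight")}
    assert len(old_sd) == 0  # every old key must be consumed exactly once
    return new_sd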
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -95,7 +95,6 @@ else:
|
||||
"UNet3DConditionModel",
|
||||
"UNetMotionModel",
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
]
|
||||
)
|
||||
@@ -132,7 +131,6 @@ else:
|
||||
)
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"AmusedScheduler",
|
||||
"CMStochasticIterativeScheduler",
|
||||
"DDIMInverseScheduler",
|
||||
"DDIMParallelScheduler",
|
||||
@@ -204,9 +202,6 @@ else:
|
||||
[
|
||||
"AltDiffusionImg2ImgPipeline",
|
||||
"AltDiffusionPipeline",
|
||||
"AmusedImg2ImgPipeline",
|
||||
"AmusedInpaintPipeline",
|
||||
"AmusedPipeline",
|
||||
"AnimateDiffPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
@@ -477,7 +472,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNet3DConditionModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
)
|
||||
from .optimization import (
|
||||
@@ -512,7 +506,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
AmusedScheduler,
|
||||
CMStochasticIterativeScheduler,
|
||||
DDIMInverseScheduler,
|
||||
DDIMParallelScheduler,
|
||||
@@ -567,9 +560,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AmusedImg2ImgPipeline,
|
||||
AmusedInpaintPipeline,
|
||||
AmusedPipeline,
|
||||
AnimateDiffPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
|
||||
@@ -59,7 +59,6 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
TEXT_ENCODER_NAME = "text_encoder"
|
||||
UNET_NAME = "unet"
|
||||
TRANSFORMER_NAME = "transformer"
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
@@ -75,7 +74,6 @@ class LoraLoaderMixin:
|
||||
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(
|
||||
@@ -663,89 +661,6 @@ class LoraLoaderMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
|
||||
network_alphas = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
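The `load_lora_into_transformer` method removed above infers each adapter's rank from the `lora_B` matrices before building a `LoraConfig`. A minimal standalone sketch of that lookup, with made-up key names and shapes (`lora_B` weights are stored as `(out_features, rank)`):

import torch

state_dict = {
    "attn1.to_q.lora_A.weight": torch.zeros(4, 64),
    "attn1.to_q.lora_B.weight": torch.zeros(64, 4),
}
rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}
assert rank == {"attn1.to_q.lora_B.weight": 4}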
@@ -871,7 +786,6 @@ class LoraLoaderMixin:
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -906,10 +820,8 @@ class LoraLoaderMixin:
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
|
||||
)
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
@@ -917,9 +829,6 @@ class LoraLoaderMixin:
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
|
||||
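The save path above merges per-component LoRA weights into one flat state dict keyed by a component prefix; this hunk only changes which components are accepted. A minimal sketch of the prefixing convention, using plain dicts in place of real LoRA layers (tensor shapes are made up):

import torch

def pack_weights(layers, prefix):
    # Namespace each component's weights before merging, mirroring the helper shown above.
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

state_dict = {}
state_dict.update(pack_weights({"lora_A.weight": torch.zeros(8, 4)}, "unet"))
state_dict.update(pack_weights({"lora_A.weight": torch.zeros(8, 4)}, "text_encoder"))
assert set(state_dict) == {"unet.lora_A.weight", "text_encoder.lora_A.weight"}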
@@ -47,7 +47,6 @@ if is_torch_available():
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
@@ -82,7 +81,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
@@ -23,7 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .lora import LoRACompatibleLinear
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormZero
|
||||
|
||||
|
||||
def _chunked_feed_forward(
|
||||
@@ -149,11 +148,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
@@ -162,7 +156,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
@@ -186,15 +179,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
@@ -206,7 +190,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
@@ -214,20 +197,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.norm2 = (
|
||||
AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
if self.use_ada_layer_norm
|
||||
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
@@ -236,32 +210,20 @@ class BasicTransformerBlock(nn.Module):
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if self.use_ada_layer_norm_continuous:
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
elif not self.use_ada_layer_norm_single:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
if not self.use_ada_layer_norm_single:
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
@@ -290,7 +252,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
@@ -304,8 +265,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
elif self.use_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.use_ada_layer_norm_single:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
@@ -355,8 +314,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
@@ -372,9 +329,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
if self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.use_ada_layer_norm_single:
|
||||
if not self.use_ada_layer_norm_single:
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
@@ -535,78 +490,6 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SkipFFTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
kv_input_dim: int,
|
||||
kv_input_dim_proj_use_bias: bool,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if kv_input_dim != dim:
|
||||
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
||||
else:
|
||||
self.kv_mapper = None
|
||||
|
||||
self.norm1 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
if self.kv_mapper is not None:
|
||||
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
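# kv_mapper (only created when kv_input_dim != dim) projects the conditioning sequence into the block's
# hidden size after a SiLU non-linearity, so attn1 and attn2 can attend over it directly.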
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
@@ -629,12 +512,10 @@ class FeedForward(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ class Encoder(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
double_z: bool = True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
@@ -125,7 +124,6 @@ class Encoder(nn.Module):
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# out
|
||||
@@ -215,7 +213,6 @@ class Decoder(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
norm_type: str = "group", # group, spatial
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
@@ -243,7 +240,6 @@ class Decoder(nn.Module):
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=temb_channels,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
|
||||
@@ -20,7 +20,6 @@ import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .normalization import RMSNorm
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
|
||||
@@ -90,11 +89,6 @@ class Downsample2D(nn.Module):
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
eps=None,
|
||||
elementwise_affine=None,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@@ -105,19 +99,8 @@ class Downsample2D(nn.Module):
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type is None:
|
||||
self.norm = None
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type: {norm_type}")
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(
|
||||
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
|
||||
)
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
@@ -134,9 +117,6 @@ class Downsample2D(nn.Module):
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
@@ -197,12 +197,11 @@ class TimestepEmbedding(nn.Module):
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
@@ -215,7 +214,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
||||
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
|
||||
@@ -13,14 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from .activations import get_activation
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
@@ -148,107 +146,3 @@ class AdaGroupNorm(nn.Module):
|
||||
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
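# Rough usage sketch (illustrative values, not part of this diff):
#   norm = AdaLayerNormContinuous(embedding_dim=1024, conditioning_embedding_dim=1024, norm_type="rms_norm")
#   out = norm(hidden_states, pooled_text_emb)  # hidden_states: (B, L, 1024), pooled_text_emb: (B, 1024)
# The conditioning embedding is projected to a per-channel scale and shift that are broadcast over the
# sequence dimension, modulating the normalized activations.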
|
||||
|
||||
if is_torch_version(">=", "2.1.0"):
|
||||
LayerNorm = nn.LayerNorm
|
||||
else:
|
||||
# Has optional bias parameter compared to torch layer norm
|
||||
# TODO: replace with torch layernorm once min required torch version >= 2.1
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
|
||||
else:
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
return hidden_states
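# RMSNorm normalizes by the root-mean-square of the last dimension, x * rsqrt(mean(x**2, dim=-1) + eps),
# optionally scaled by a learned weight; unlike LayerNorm it subtracts no mean and has no bias.
# The statistics are computed in float32 for numerical stability.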
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * nx) + self.beta + x
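# GlobalResponseNorm (ConvNeXt-V2) rescales each channel by its spatial L2 norm relative to the mean norm
# across channels, then applies learnable gamma/beta with a residual connection; inputs are expected in
# channels-last layout (B, H, W, C) to match the (1, 1, 1, dim) parameters.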
|
||||
@@ -20,7 +20,6 @@ import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .normalization import RMSNorm
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
@@ -96,13 +95,6 @@ class Upsample2D(nn.Module):
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
kernel_size: Optional[int] = None,
|
||||
padding=1,
|
||||
norm_type=None,
|
||||
eps=None,
|
||||
elementwise_affine=None,
|
||||
bias=True,
|
||||
interpolate=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@@ -110,29 +102,13 @@ class Upsample2D(nn.Module):
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
self.interpolate = interpolate
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type is None:
|
||||
self.norm = None
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type: {norm_type}")
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
if kernel_size is None:
|
||||
kernel_size = 4
|
||||
conv = nn.ConvTranspose2d(
|
||||
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
|
||||
)
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
if kernel_size is None:
|
||||
kernel_size = 3
|
||||
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
@@ -148,9 +124,6 @@ class Upsample2D(nn.Module):
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
@@ -167,11 +140,10 @@ class Upsample2D(nn.Module):
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if self.interpolate:
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import GlobalResponseNorm, RMSNorm
|
||||
from .resnet import Downsample2D, Upsample2D
|
||||
|
||||
|
||||
class UVit2DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
# global config
|
||||
hidden_size: int = 1024,
|
||||
use_bias: bool = False,
|
||||
hidden_dropout: float = 0.0,
|
||||
# conditioning dimensions
|
||||
cond_embed_dim: int = 768,
|
||||
micro_cond_encode_dim: int = 256,
|
||||
micro_cond_embed_dim: int = 1280,
|
||||
encoder_hidden_size: int = 768,
|
||||
# num tokens
|
||||
vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded
|
||||
codebook_size: int = 8192,
|
||||
# `UVit2DConvEmbed`
|
||||
in_channels: int = 768,
|
||||
block_out_channels: int = 768,
|
||||
num_res_blocks: int = 3,
|
||||
downsample: bool = False,
|
||||
upsample: bool = False,
|
||||
block_num_heads: int = 12,
|
||||
# `TransformerLayer`
|
||||
num_hidden_layers: int = 22,
|
||||
num_attention_heads: int = 16,
|
||||
# `Attention`
|
||||
attention_dropout: float = 0.0,
|
||||
# `FeedForward`
|
||||
intermediate_size: int = 2816,
|
||||
# `Norm`
|
||||
layer_norm_eps: float = 1e-6,
|
||||
ln_elementwise_affine: bool = True,
|
||||
sample_size: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
|
||||
self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
|
||||
|
||||
self.embed = UVit2DConvEmbed(
|
||||
in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
|
||||
)
|
||||
|
||||
self.cond_embed = TimestepEmbedding(
|
||||
micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
|
||||
)
|
||||
|
||||
self.down_block = UVitBlock(
|
||||
block_out_channels,
|
||||
num_res_blocks,
|
||||
hidden_size,
|
||||
hidden_dropout,
|
||||
ln_elementwise_affine,
|
||||
layer_norm_eps,
|
||||
use_bias,
|
||||
block_num_heads,
|
||||
attention_dropout,
|
||||
downsample,
|
||||
False,
|
||||
)
|
||||
|
||||
self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
|
||||
self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)
|
||||
|
||||
self.transformer_layers = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=hidden_size // num_attention_heads,
|
||||
dropout=hidden_dropout,
|
||||
cross_attention_dim=hidden_size,
|
||||
attention_bias=use_bias,
|
||||
norm_type="ada_norm_continuous",
|
||||
ada_norm_continous_conditioning_embedding_dim=hidden_size,
|
||||
norm_elementwise_affine=ln_elementwise_affine,
|
||||
norm_eps=layer_norm_eps,
|
||||
ada_norm_bias=use_bias,
|
||||
ff_inner_dim=intermediate_size,
|
||||
ff_bias=use_bias,
|
||||
attention_out_bias=use_bias,
|
||||
)
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
|
||||
self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)
|
||||
|
||||
self.up_block = UVitBlock(
|
||||
block_out_channels,
|
||||
num_res_blocks,
|
||||
hidden_size,
|
||||
hidden_dropout,
|
||||
ln_elementwise_affine,
|
||||
layer_norm_eps,
|
||||
use_bias,
|
||||
block_num_heads,
|
||||
attention_dropout,
|
||||
downsample=False,
|
||||
upsample=upsample,
|
||||
)
|
||||
|
||||
self.mlm_layer = ConvMlmLayer(
|
||||
block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
pass
|
||||
|
||||
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
|
||||
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
|
||||
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
|
||||
|
||||
micro_cond_embeds = get_timestep_embedding(
|
||||
micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
|
||||
micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))
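# Each of the 5 micro-conditioning scalars (original width/height, crop coords, aesthetic score) gets a
# sinusoidal embedding of size micro_cond_encode_dim (256); flattening yields 5 * 256 = 1280 features per
# sample, matching micro_cond_embed_dim.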
|
||||
pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
|
||||
pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
|
||||
pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)
|
||||
|
||||
hidden_states = self.embed(input_ids)
|
||||
|
||||
hidden_states = self.down_block(
|
||||
hidden_states,
|
||||
pooled_text_emb=pooled_text_emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
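# Flatten the (B, C, H, W) feature map into a (B, H*W, C) token sequence for the transformer layers;
# the inverse reshape below restores the spatial layout before the up block.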
|
||||
hidden_states = self.project_to_hidden_norm(hidden_states)
|
||||
hidden_states = self.project_to_hidden(hidden_states)
|
||||
|
||||
for layer in self.transformer_layers:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def layer_(*args):
|
||||
return checkpoint(layer, *args)
|
||||
|
||||
else:
|
||||
layer_ = layer
|
||||
|
||||
hidden_states = layer_(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
|
||||
)
|
||||
|
||||
hidden_states = self.project_from_hidden_norm(hidden_states)
|
||||
hidden_states = self.project_from_hidden(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
|
||||
|
||||
hidden_states = self.up_block(
|
||||
hidden_states,
|
||||
pooled_text_emb=pooled_text_emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
logits = self.mlm_layer(hidden_states)
|
||||
|
||||
return logits
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
|
||||
class UVit2DConvEmbed(nn.Module):
|
||||
def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
|
||||
super().__init__()
|
||||
self.embeddings = nn.Embedding(vocab_size, in_channels)
|
||||
self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
|
||||
self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, input_ids):
|
||||
embeddings = self.embeddings(input_ids)
|
||||
embeddings = self.layer_norm(embeddings)
|
||||
embeddings = embeddings.permute(0, 3, 1, 2)
|
||||
embeddings = self.conv(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class UVitBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_res_blocks: int,
|
||||
hidden_size,
|
||||
hidden_dropout,
|
||||
ln_elementwise_affine,
|
||||
layer_norm_eps,
|
||||
use_bias,
|
||||
block_num_heads,
|
||||
attention_dropout,
|
||||
downsample: bool,
|
||||
upsample: bool,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if downsample:
|
||||
self.downsample = Downsample2D(
|
||||
channels,
|
||||
use_conv=True,
|
||||
padding=0,
|
||||
name="Conv2d_0",
|
||||
kernel_size=2,
|
||||
norm_type="rms_norm",
|
||||
eps=layer_norm_eps,
|
||||
elementwise_affine=ln_elementwise_affine,
|
||||
bias=use_bias,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ConvNextBlock(
|
||||
channels,
|
||||
layer_norm_eps,
|
||||
ln_elementwise_affine,
|
||||
use_bias,
|
||||
hidden_dropout,
|
||||
hidden_size,
|
||||
)
|
||||
for i in range(num_res_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.attention_blocks = nn.ModuleList(
|
||||
[
|
||||
SkipFFTransformerBlock(
|
||||
channels,
|
||||
block_num_heads,
|
||||
channels // block_num_heads,
|
||||
hidden_size,
|
||||
use_bias,
|
||||
attention_dropout,
|
||||
channels,
|
||||
attention_bias=use_bias,
|
||||
attention_out_bias=use_bias,
|
||||
)
|
||||
for _ in range(num_res_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample2D(
|
||||
channels,
|
||||
use_conv_transpose=True,
|
||||
kernel_size=2,
|
||||
padding=0,
|
||||
name="conv",
|
||||
norm_type="rms_norm",
|
||||
eps=layer_norm_eps,
|
||||
elementwise_affine=ln_elementwise_affine,
|
||||
bias=use_bias,
|
||||
interpolate=False,
|
||||
)
|
||||
else:
|
||||
self.upsample = None
|
||||
|
||||
def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
|
||||
x = res_block(x, pooled_text_emb)
|
||||
|
||||
batch_size, channels, height, width = x.shape
|
||||
x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
|
||||
x = attention_block(
|
||||
x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
|
||||
)
|
||||
x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
|
||||
|
||||
if self.upsample is not None:
|
||||
x = self.upsample(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ConvNextBlock(nn.Module):
|
||||
def __init__(
|
||||
self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
|
||||
):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=channels,
|
||||
bias=use_bias,
|
||||
)
|
||||
self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
|
||||
self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
|
||||
self.channelwise_act = nn.GELU()
|
||||
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
|
||||
self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
|
||||
self.channelwise_dropout = nn.Dropout(hidden_dropout)
|
||||
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
|
||||
|
||||
def forward(self, x, cond_embeds):
|
||||
x_res = x
|
||||
|
||||
x = self.depthwise(x)
|
||||
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.norm(x)
|
||||
|
||||
x = self.channelwise_linear_1(x)
|
||||
x = self.channelwise_act(x)
|
||||
x = self.channelwise_norm(x)
|
||||
x = self.channelwise_linear_2(x)
|
||||
x = self.channelwise_dropout(x)
|
||||
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
||||
x = x + x_res
|
||||
|
||||
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
|
||||
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
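# FiLM-style conditioning: the pooled text/micro-conditioning embedding is mapped to a per-channel
# scale and shift that modulate the residual output.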
|
||||
return x
|
||||
|
||||
|
||||
class ConvMlmLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_out_channels: int,
|
||||
in_channels: int,
|
||||
use_bias: bool,
|
||||
ln_elementwise_affine: bool,
|
||||
layer_norm_eps: float,
|
||||
codebook_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
|
||||
self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
|
||||
self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
logits = self.conv2(hidden_states)
|
||||
return logits
|
||||
@@ -88,9 +88,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
scaling_factor: float = 0.18215,
|
||||
norm_type: str = "group", # group, spatial
|
||||
mid_block_add_attention=True,
|
||||
lookup_from_codebook=False,
|
||||
force_upcast=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -104,7 +101,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=False,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
@@ -123,7 +119,6 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_type=norm_type,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
@apply_forward_hook
|
||||
@@ -138,13 +133,11 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
|
||||
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, _, _ = self.quantize(h)
|
||||
elif self.config.lookup_from_codebook:
|
||||
quant = self.quantize.get_codebook_entry(h, shape)
|
||||
else:
|
||||
quant = h
|
||||
quant2 = self.post_quant_conv(quant)
|
||||
|
||||
@@ -108,7 +108,6 @@ else:
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
|
||||
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
@@ -186,11 +185,12 @@ else:
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
@@ -210,8 +210,6 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
_import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
|
||||
_import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
|
||||
_import_structure["t2i_adapter"] = [
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
@@ -343,7 +341,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
|
||||
from .animatediff import AnimateDiffPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import (
|
||||
@@ -430,6 +427,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
@@ -438,8 +437,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
|
||||
from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AmusedImg2ImgPipeline,
|
||||
AmusedInpaintPipeline,
|
||||
AmusedPipeline,
|
||||
)
|
||||
|
||||
_dummy_objects.update(
|
||||
{
|
||||
"AmusedPipeline": AmusedPipeline,
|
||||
"AmusedImg2ImgPipeline": AmusedImg2ImgPipeline,
|
||||
"AmusedInpaintPipeline": AmusedInpaintPipeline,
|
||||
}
|
||||
)
|
||||
else:
|
||||
_import_structure["pipeline_amused"] = ["AmusedPipeline"]
|
||||
_import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"]
|
||||
_import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AmusedPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_amused import AmusedPipeline
|
||||
from .pipeline_amused_img2img import AmusedImg2ImgPipeline
|
||||
from .pipeline_amused_inpaint import AmusedInpaintPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,328 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AmusedPipeline
|
||||
|
||||
>>> pipe = AmusedPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AmusedPipeline(DiffusionPipeline):
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
transformer: UVit2DModel
|
||||
scheduler: AmusedScheduler
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
transformer: UVit2DModel,
|
||||
scheduler: AmusedScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vqvae=vqvae,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
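# e.g. a VQ-VAE with 3 entries in block_out_channels downsamples by 2**(3 - 1) = 4, so a 512x512 image
# corresponds to a 128x128 grid of latent tokens (illustrative numbers, not taken from the checkpoint).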
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 12,
|
||||
guidance_scale: float = 10.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.IntTensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
output_type="pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
micro_conditioning_aesthetic_score: int = 6,
|
||||
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
||||
):
|
||||
"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 12):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 10.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.IntTensor`, *optional*):
|
||||
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
generation. If not provided, the starting latents will be completely masked.
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
||||
pooled and projected final hidden states.
|
||||
encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Analogous to `encoder_hidden_states` for the positive prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
A function that is called every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
||||
The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
|
||||
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
||||
Configures the temperature scheduler on `self.scheduler`; see `AmusedScheduler#set_timesteps`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
||||
`tuple` is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
||||
prompt_embeds is None and encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
||||
|
||||
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
||||
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError(
"pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
)
|
||||
|
||||
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
||||
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if height is None:
|
||||
height = self.transformer.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if width is None:
|
||||
width = self.transformer.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if prompt_embeds is None:
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
prompt_embeds = outputs.text_embeds
|
||||
encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
if negative_prompt_embeds is None:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * len(prompt)
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
|
||||
input_ids = self.tokenizer(
|
||||
negative_prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
negative_prompt_embeds = outputs.text_embeds
|
||||
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
||||
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
||||
|
||||
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
||||
# and the crop coordinates. This is how it was done in the original code base
|
||||
micro_conds = torch.tensor(
|
||||
[
|
||||
width,
|
||||
height,
|
||||
micro_conditioning_crop_coord[0],
|
||||
micro_conditioning_crop_coord[1],
|
||||
micro_conditioning_aesthetic_score,
|
||||
],
|
||||
device=self._execution_device,
|
||||
dtype=encoder_hidden_states.dtype,
|
||||
)
|
||||
micro_conds = micro_conds.unsqueeze(0)
|
||||
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
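# micro_conds now has shape (2 * batch_size, 5) under classifier-free guidance (or (batch_size, 5)
# otherwise): [width, height, crop_x, crop_y, aesthetic_score], matching the duplicated model input.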
|
||||
shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
|
||||
if latents is None:
|
||||
latents = torch.full(
|
||||
shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
|
||||
)
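# Generation starts from a fully masked grid of VQ token ids (every entry is mask_token_id); at each
# step the transformer predicts logits for all positions and the scheduler progressively unmasks them.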
|
||||
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
||||
|
||||
num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, timestep in enumerate(self.scheduler.timesteps):
|
||||
if guidance_scale > 1.0:
|
||||
model_input = torch.cat([latents] * 2)
|
||||
else:
|
||||
model_input = latents
|
||||
|
||||
model_output = self.transformer(
|
||||
model_input,
|
||||
micro_conds=micro_conds,
|
||||
pooled_text_emb=prompt_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
uncond_logits, cond_logits = model_output.chunk(2)
|
||||
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
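# Classifier-free guidance applied directly to the token logits: the batch was duplicated into
# unconditional and conditional halves, and the conditional direction is amplified by guidance_scale.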
|
||||
latents = self.scheduler.step(
|
||||
model_output=model_output,
|
||||
timestep=timestep,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if i == len(self.scheduler.timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.float()
|
||||
|
||||
output = self.vqvae.decode(
|
||||
latents,
|
||||
force_not_quantize=True,
|
||||
shape=(
|
||||
batch_size,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
self.vqvae.config.latent_channels,
|
||||
),
|
||||
).sample.clip(0, 1)
|
||||
output = self.image_processor.postprocess(output, output_type)
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.half()
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return ImagePipelineOutput(output)
|
||||
@@ -1,347 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AmusedImg2ImgPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "winter mountains"
|
||||
>>> input_image = (
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg"
|
||||
... )
|
||||
... .resize((512, 512))
|
||||
... .convert("RGB")
|
||||
... )
|
||||
>>> image = pipe(prompt, input_image).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AmusedImg2ImgPipeline(DiffusionPipeline):
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
transformer: UVit2DModel
|
||||
scheduler: AmusedScheduler
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
||||
|
||||
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
|
||||
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
|
||||
# off the meta device. There should be a way to fix this instead of just not offloading it
|
||||
_exclude_from_cpu_offload = ["vqvae"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
transformer: UVit2DModel,
|
||||
scheduler: AmusedScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vqvae=vqvae,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.5,
|
||||
num_inference_steps: int = 12,
|
||||
guidance_scale: float = 10.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
output_type="pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
micro_conditioning_aesthetic_score: int = 6,
|
||||
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
||||
):
|
||||
"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
||||
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
strength (`float`, *optional*, defaults to 0.5):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 10.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
||||
pooled and projected final hidden states.
|
||||
encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Analogous to `encoder_hidden_states` for the positive prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that is called every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
||||
The targeted aesthetic score according to the LAION aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
|
||||
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
||||
Configures the temperature scheduler on `self.scheduler`; see `AmusedScheduler#set_timesteps`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
||||
`tuple` is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
||||
prompt_embeds is None and encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
||||
|
||||
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
||||
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
|
||||
)
|
||||
|
||||
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
||||
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
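# Prompt encoding: the CLIP pooled projection (`text_embeds`) becomes the pooled text embedding and the
# penultimate hidden states become `encoder_hidden_states` for the transformer.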
|
||||
|
||||
if prompt_embeds is None:
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
prompt_embeds = outputs.text_embeds
|
||||
encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
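# Classifier-free guidance: the negative prompt (empty strings by default) is encoded the same way and
# concatenated in front of the positive embeddings so conditional and unconditional logits come from a single forward pass.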
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
if negative_prompt_embeds is None:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * len(prompt)
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
|
||||
input_ids = self.tokenizer(
|
||||
negative_prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
negative_prompt_embeds = outputs.text_embeds
|
||||
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
||||
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
||||
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
||||
# and the crop coordinates. This is how it was done in the original code base
|
||||
micro_conds = torch.tensor(
|
||||
[
|
||||
width,
|
||||
height,
|
||||
micro_conditioning_crop_coord[0],
|
||||
micro_conditioning_crop_coord[1],
|
||||
micro_conditioning_aesthetic_score,
|
||||
],
|
||||
device=self._execution_device,
|
||||
dtype=encoder_hidden_states.dtype,
|
||||
)
|
||||
|
||||
micro_conds = micro_conds.unsqueeze(0)
|
||||
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
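# `strength` shortens the schedule: only the last `int(len(timesteps) * strength)` timesteps are run,
# starting at `start_timestep_idx`.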
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
||||
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
|
||||
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
|
||||
|
||||
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.float()
|
||||
|
||||
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
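# The quantizer maps each latent vector to a discrete codebook index (entry [2][2] of its output);
# `add_noise` then replaces a timestep-dependent fraction of those tokens with the mask token.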
|
||||
latents_bsz, channels, latents_height, latents_width = latents.shape
|
||||
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
|
||||
)
|
||||
latents = latents.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
|
||||
timestep = self.scheduler.timesteps[i]
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
model_input = torch.cat([latents] * 2)
|
||||
else:
|
||||
model_input = latents
|
||||
|
||||
model_output = self.transformer(
|
||||
model_input,
|
||||
micro_conds=micro_conds,
|
||||
pooled_text_emb=prompt_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
uncond_logits, cond_logits = model_output.chunk(2)
|
||||
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
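# The scheduler samples token ids from the (guided) logits and re-masks the lowest-confidence positions
# for the next iteration.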
|
||||
|
||||
latents = self.scheduler.step(
|
||||
model_output=model_output,
|
||||
timestep=timestep,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
output = self.vqvae.decode(
|
||||
latents,
|
||||
force_not_quantize=True,
|
||||
shape=(
|
||||
batch_size,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
self.vqvae.config.latent_channels,
|
||||
),
|
||||
).sample.clip(0, 1)
|
||||
output = self.image_processor.postprocess(output, output_type)
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.half()
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return ImagePipelineOutput(output)
|
||||
@@ -1,378 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AmusedInpaintPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "fall mountains"
|
||||
>>> input_image = (
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg"
|
||||
... )
|
||||
... .resize((512, 512))
|
||||
... .convert("RGB")
|
||||
... )
|
||||
>>> mask = (
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
|
||||
... )
|
||||
... .resize((512, 512))
|
||||
... .convert("L")
|
||||
... )
|
||||
>>> pipe(prompt, input_image, mask).images[0].save("out.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AmusedInpaintPipeline(DiffusionPipeline):
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
transformer: UVit2DModel
|
||||
scheduler: AmusedScheduler
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
||||
|
||||
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
|
||||
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
|
||||
# off the meta device. There should be a way to fix this instead of just not offloading it
|
||||
_exclude_from_cpu_offload = ["vqvae"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
transformer: UVit2DModel,
|
||||
scheduler: AmusedScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vqvae=vqvae,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
do_resize=True,
|
||||
)
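# Note: `AmusedScheduler` defaults to a cosine masking schedule; the inpainting pipeline overrides it to linear here.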
|
||||
self.scheduler.register_to_config(masking_schedule="linear")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 12,
|
||||
guidance_scale: float = 10.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
output_type="pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
micro_conditioning_aesthetic_score: int = 6,
|
||||
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
||||
):
|
||||
"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
||||
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
||||
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
||||
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
||||
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
||||
H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, the expected shape would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
||||
1)`, or `(H, W)`.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 10.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
||||
pooled and projected final hidden states.
|
||||
encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
negative_encoder_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
Analogous to `encoder_hidden_states` for the positive prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that is called every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
||||
The targeted aesthetic score according to the LAION aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
|
||||
and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
|
||||
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
||||
Configures the temperature scheduler on `self.scheduler`; see `AmusedScheduler#set_timesteps`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
||||
`tuple` is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
||||
prompt_embeds is None and encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
||||
|
||||
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
||||
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
|
||||
)
|
||||
|
||||
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
||||
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
input_ids = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
prompt_embeds = outputs.text_embeds
|
||||
encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
if negative_prompt_embeds is None:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * len(prompt)
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
|
||||
input_ids = self.tokenizer(
|
||||
negative_prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids.to(self._execution_device)
|
||||
|
||||
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
negative_prompt_embeds = outputs.text_embeds
|
||||
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
||||
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
||||
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
||||
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
||||
# and the crop coordinates. This is how it was done in the original code base
|
||||
micro_conds = torch.tensor(
|
||||
[
|
||||
width,
|
||||
height,
|
||||
micro_conditioning_crop_coord[0],
|
||||
micro_conditioning_crop_coord[1],
|
||||
micro_conditioning_aesthetic_score,
|
||||
],
|
||||
device=self._execution_device,
|
||||
dtype=encoder_hidden_states.dtype,
|
||||
)
|
||||
|
||||
micro_conds = micro_conds.unsqueeze(0)
|
||||
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
||||
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
|
||||
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
|
||||
|
||||
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.float()
|
||||
|
||||
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
|
||||
latents_bsz, channels, latents_height, latents_width = latents.shape
|
||||
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
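# The mask is downsampled to the token grid, masked positions are overwritten with `mask_token_id`, and the
# starting mask ratio is recorded so the scheduler re-masks proportionally fewer tokens each step.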
|
||||
|
||||
mask = self.mask_processor.preprocess(
|
||||
mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
|
||||
)
|
||||
mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
|
||||
latents[mask] = self.scheduler.config.mask_token_id
|
||||
|
||||
starting_mask_ratio = mask.sum() / latents.numel()
|
||||
|
||||
latents = latents.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
|
||||
timestep = self.scheduler.timesteps[i]
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
model_input = torch.cat([latents] * 2)
|
||||
else:
|
||||
model_input = latents
|
||||
|
||||
model_output = self.transformer(
|
||||
model_input,
|
||||
micro_conds=micro_conds,
|
||||
pooled_text_emb=prompt_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
uncond_logits, cond_logits = model_output.chunk(2)
|
||||
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
model_output=model_output,
|
||||
timestep=timestep,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
starting_mask_ratio=starting_mask_ratio,
|
||||
).prev_sample
|
||||
|
||||
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, timestep, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
output = latents
|
||||
else:
|
||||
output = self.vqvae.decode(
|
||||
latents,
|
||||
force_not_quantize=True,
|
||||
shape=(
|
||||
batch_size,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
self.vqvae.config.latent_channels,
|
||||
),
|
||||
).sample.clip(0, 1)
|
||||
output = self.image_processor.postprocess(output, output_type)
|
||||
|
||||
if needs_upcasting:
|
||||
self.vqvae.half()
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return ImagePipelineOutput(output)
|
||||
@@ -633,7 +633,7 @@ class StableDiffusionControlNetPipeline(
|
||||
# When `image` is a nested list:
|
||||
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
|
||||
elif any(isinstance(i, list) for i in image):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
@@ -659,7 +659,7 @@ class StableDiffusionControlNetPipeline(
|
||||
):
|
||||
if isinstance(controlnet_conditioning_scale, list):
|
||||
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
|
||||
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
||||
self.controlnet.nets
|
||||
):
|
||||
|
||||
@@ -40,7 +40,9 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
|
||||
_import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
|
||||
@@ -64,15 +66,18 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
|
||||
_dummy_objects.update(
|
||||
{
|
||||
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
|
||||
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
|
||||
}
|
||||
)
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_onnx_available()):
|
||||
@@ -123,6 +128,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_stable_diffusion_latent_upscale import (
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
|
||||
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
|
||||
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||
from .pipeline_stable_unclip import StableUnCLIPPipeline
|
||||
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
|
||||
@@ -145,7 +152,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_stable_diffusion_depth2img import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
|
||||
@@ -768,10 +768,6 @@ class StableDiffusionPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -913,7 +909,6 @@ class StableDiffusionPipeline(
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -991,9 +986,6 @@ class StableDiffusionPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -832,10 +832,6 @@ class StableDiffusionImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -967,7 +963,6 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1046,9 +1041,6 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -958,10 +958,6 @@ class StableDiffusionInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -1148,7 +1144,6 @@ class StableDiffusionInpaintPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1293,9 +1288,6 @@ class StableDiffusionInpaintPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -33,8 +33,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -849,10 +849,6 @@ class StableDiffusionXLPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1071,7 +1067,6 @@ class StableDiffusionXLPipeline(
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1201,9 +1196,6 @@ class StableDiffusionXLPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -990,10 +990,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1225,7 +1221,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1381,9 +1376,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -1210,10 +1210,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1466,7 +1462,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1689,8 +1684,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
else:
|
||||
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
||||
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
|
||||
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
||||
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
||||
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
||||
@@ -130,7 +129,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
||||
from .scheduling_amused import AmusedScheduler
|
||||
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
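# Gumbel noise added to the log-confidences makes the re-masking in `mask_by_random_topk` a stochastic,
# temperature-controlled selection rather than a hard top-k cut.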
|
||||
|
||||
|
||||
def gumbel_noise(t, generator=None):
|
||||
device = generator.device if generator is not None else t.device
|
||||
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
||||
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
||||
|
||||
|
||||
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
||||
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
||||
sorted_confidence = torch.sort(confidence, dim=-1).values
|
||||
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
||||
masking = confidence < cut_off
|
||||
return masking
|
||||
|
||||
|
||||
@dataclass
|
||||
class AmusedSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||
order = 1
|
||||
|
||||
temperatures: torch.Tensor
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
mask_token_id: int,
|
||||
masking_schedule: str = "cosine",
|
||||
):
|
||||
self.temperatures = None
|
||||
self.timesteps = None
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
||||
device: Union[str, torch.device] = None,
|
||||
):
|
||||
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
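# A (start, end) temperature pair is interpolated linearly over the steps; a scalar temperature is
# interpolated down to 0.01.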
|
||||
|
||||
if isinstance(temperature, (tuple, list)):
|
||||
self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
|
||||
else:
|
||||
self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: torch.long,
|
||||
sample: torch.LongTensor,
|
||||
starting_mask_ratio: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[AmusedSchedulerOutput, Tuple]:
|
||||
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
|
||||
|
||||
if two_dim_input:
|
||||
batch_size, codebook_size, height, width = model_output.shape
|
||||
sample = sample.reshape(batch_size, height * width)
|
||||
model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)
|
||||
|
||||
unknown_map = sample == self.config.mask_token_id
|
||||
|
||||
probs = model_output.softmax(dim=-1)
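# Sample one token id per position from the predicted distribution (on the generator's device, in float32 on
# CPU), keeping the original token wherever the input was not masked.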
|
||||
|
||||
device = probs.device
|
||||
probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU
|
||||
if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
|
||||
probs_ = probs_.float() # multinomial is not implemented for cpu half precision
|
||||
probs_ = probs_.reshape(-1, probs.size(-1))
|
||||
pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
|
||||
pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
|
||||
pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
|
||||
|
||||
if timestep == 0:
|
||||
prev_sample = pred_original_sample
|
||||
else:
|
||||
seq_len = sample.shape[1]
|
||||
step_idx = (self.timesteps == timestep).nonzero()
|
||||
ratio = (step_idx + 1) / len(self.timesteps)
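# Decide how many tokens stay masked after this step from the cosine/linear schedule, never exceeding the
# number currently masked minus one and always keeping at least one token masked.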
|
||||
|
||||
if self.config.masking_schedule == "cosine":
|
||||
mask_ratio = torch.cos(ratio * math.pi / 2)
|
||||
elif self.config.masking_schedule == "linear":
|
||||
mask_ratio = 1 - ratio
|
||||
else:
|
||||
raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
|
||||
|
||||
mask_ratio = starting_mask_ratio * mask_ratio
|
||||
|
||||
mask_len = (seq_len * mask_ratio).floor()
|
||||
# do not mask more than amount previously masked
|
||||
mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
||||
# mask at least one
|
||||
mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
|
||||
|
||||
selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
|
||||
# Ignores the tokens given in the input by overwriting their confidence.
|
||||
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
||||
|
||||
masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)
|
||||
|
||||
# Masks tokens with lower confidence.
|
||||
prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)
|
||||
|
||||
if two_dim_input:
|
||||
prev_sample = prev_sample.reshape(batch_size, height, width)
|
||||
pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, pred_original_sample)
|
||||
|
||||
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
|
||||
|
||||
def add_noise(self, sample, timesteps, generator=None):
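# Used by the img2img pipeline: randomly replaces a schedule-dependent fraction of token ids with
# `mask_token_id` for the given timestep.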
|
||||
step_idx = (self.timesteps == timesteps).nonzero()
|
||||
ratio = (step_idx + 1) / len(self.timesteps)
|
||||
|
||||
if self.config.masking_schedule == "cosine":
|
||||
mask_ratio = torch.cos(ratio * math.pi / 2)
|
||||
elif self.config.masking_schedule == "linear":
|
||||
mask_ratio = 1 - ratio
|
||||
else:
|
||||
raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
|
||||
|
||||
mask_indices = (
|
||||
torch.rand(
|
||||
sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
|
||||
).to(sample.device)
|
||||
< mask_ratio
|
||||
)
|
||||
|
||||
masked_sample = sample.clone()
|
||||
|
||||
masked_sample[mask_indices] = self.config.mask_token_id
|
||||
|
||||
return masked_sample
|
||||
@@ -317,21 +317,6 @@ class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UVit2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -675,21 +660,6 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AmusedScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CMStochasticIterativeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -32,51 +32,6 @@ class AltDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AmusedImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AmusedInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AmusedPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AnimateDiffPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool:
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
|
||||
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
|
||||
def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor:
|
||||
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
||||
|
||||
This version of the method comes from here:
|
||||
@@ -121,8 +121,8 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
|
||||
|
||||
|
||||
def apply_freeu(
|
||||
resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Applies the FreeU mechanism as introduced in https:
|
||||
//arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
|
||||
|
||||
|
||||
@@ -111,16 +111,12 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None):
|
||||
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
|
||||
rank = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(**self.unet_kwargs)
|
||||
|
||||
scheduler = scheduler_cls(**self.scheduler_kwargs)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(**self.vae_kwargs)
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
|
||||
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
|
||||
|
||||
@@ -129,14 +125,11 @@ class PeftLoraLoaderMixinTests:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
|
||||
|
||||
text_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=rank,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
init_lora_weights=False,
|
||||
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
@@ -1404,36 +1397,7 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
scheduler_cls = DDIMScheduler
|
||||
scheduler_kwargs = {
|
||||
"beta_start": 0.00085,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"clip_sample": False,
|
||||
"set_alpha_to_one": False,
|
||||
"steps_offset": 1,
|
||||
}
|
||||
unet_kwargs = {
|
||||
"block_out_channels": (32, 64),
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
"cross_attention_dim": 32,
|
||||
}
|
||||
vae_kwargs = {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
}
|
||||
|
||||
class LoraIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
import gc
|
||||
|
||||
@@ -1686,43 +1650,7 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
has_two_text_encoders = True
|
||||
pipeline_class = StableDiffusionXLPipeline
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
scheduler_kwargs = {
|
||||
"beta_start": 0.00085,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"timestep_spacing": "leading",
|
||||
"steps_offset": 1,
|
||||
}
|
||||
unet_kwargs = {
|
||||
"block_out_channels": (32, 64),
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
"attention_head_dim": (2, 4),
|
||||
"use_linear_projection": True,
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_time_embed_dim": 8,
|
||||
"transformer_layers_per_block": (1, 2),
|
||||
"projection_class_embeddings_input_dim": 80, # 6 * 8 + 32
|
||||
"cross_attention_dim": 64,
|
||||
}
|
||||
vae_kwargs = {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
"latent_channels": 4,
|
||||
"sample_size": 128,
|
||||
}
|
||||
|
||||
class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
import gc
|
||||
|
||||
@@ -1949,9 +1877,7 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
).images
|
||||
images_without_fusion = images.flatten()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images_with_fusion, images_without_fusion)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora_unfusion_effectivity(self):
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AmusedPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = UVit2DModel(
|
||||
hidden_size=32,
|
||||
use_bias=False,
|
||||
hidden_dropout=0.0,
|
||||
cond_embed_dim=32,
|
||||
micro_cond_encode_dim=2,
|
||||
micro_cond_embed_dim=10,
|
||||
encoder_hidden_size=32,
|
||||
vocab_size=32,
|
||||
codebook_size=32,
|
||||
in_channels=32,
|
||||
block_out_channels=32,
|
||||
num_res_blocks=1,
|
||||
downsample=True,
|
||||
upsample=True,
|
||||
block_num_heads=1,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
attention_dropout=0.0,
|
||||
intermediate_size=32,
|
||||
layer_norm_eps=1e-06,
|
||||
ln_elementwise_affine=True,
|
||||
)
|
||||
scheduler = AmusedScheduler(mask_token_id=31)
|
||||
torch.manual_seed(0)
|
||||
vqvae = VQModel(
|
||||
act_fn="silu",
|
||||
block_out_channels=[32],
|
||||
down_block_types=[
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
in_channels=3,
|
||||
latent_channels=32,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
num_vq_embeddings=32,
|
||||
out_channels=3,
|
||||
sample_size=32,
|
||||
up_block_types=[
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
mid_block_add_attention=False,
|
||||
lookup_from_codebook=True,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=3,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"scheduler": scheduler,
|
||||
"vqvae": vqvae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
"height": 4,
|
||||
"width": 4,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2]):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class AmusedPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.4011, 0.3992, 0.3790, 0.3856, 0.3772, 0.3711, 0.3919, 0.3850, 0.3625])
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158])
|
||||
assert np.abs(image_slice - expected_slice).max() < 7e-3
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.9960, 0.9960, 0.9946, 0.9980, 0.9947, 0.9932, 0.9960, 0.9961, 0.9947])
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-512", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.9983, 1.0, 1.0, 1.0, 1.0, 0.9989, 0.9994, 0.9976, 0.9977])
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
@@ -1,239 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AmusedImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {
|
||||
"latents",
|
||||
}
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = UVit2DModel(
|
||||
hidden_size=32,
|
||||
use_bias=False,
|
||||
hidden_dropout=0.0,
|
||||
cond_embed_dim=32,
|
||||
micro_cond_encode_dim=2,
|
||||
micro_cond_embed_dim=10,
|
||||
encoder_hidden_size=32,
|
||||
vocab_size=32,
|
||||
codebook_size=32,
|
||||
in_channels=32,
|
||||
block_out_channels=32,
|
||||
num_res_blocks=1,
|
||||
downsample=True,
|
||||
upsample=True,
|
||||
block_num_heads=1,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
attention_dropout=0.0,
|
||||
intermediate_size=32,
|
||||
layer_norm_eps=1e-06,
|
||||
ln_elementwise_affine=True,
|
||||
)
|
||||
scheduler = AmusedScheduler(mask_token_id=31)
|
||||
torch.manual_seed(0)
|
||||
vqvae = VQModel(
|
||||
act_fn="silu",
|
||||
block_out_channels=[32],
|
||||
down_block_types=[
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
in_channels=3,
|
||||
latent_channels=32,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
num_vq_embeddings=32,
|
||||
out_channels=3,
|
||||
sample_size=32,
|
||||
up_block_types=[
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
mid_block_add_attention=False,
|
||||
lookup_from_codebook=True,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=3,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"scheduler": scheduler,
|
||||
"vqvae": vqvae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2]):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
|
||||
.resize((256, 256))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.9990, 0.9954, 1.0])
|
||||
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
"huggingface/amused-256", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
|
||||
.resize((256, 256))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.9980, 0.9980, 0.9940, 0.9944, 0.9960, 0.9908, 1.0, 1.0, 0.9986])
|
||||
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
|
||||
.resize((512, 512))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.1344, 0.0985, 0.0, 0.1194, 0.1809, 0.0765, 0.0854, 0.1371, 0.0933])
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
"huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
|
||||
.resize((512, 512))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.1536, 0.1767, 0.0227, 0.1079, 0.2400, 0.1427, 0.1511, 0.1564, 0.1542])
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
@@ -1,277 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AmusedInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {
|
||||
"latents",
|
||||
}
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = UVit2DModel(
|
||||
hidden_size=32,
|
||||
use_bias=False,
|
||||
hidden_dropout=0.0,
|
||||
cond_embed_dim=32,
|
||||
micro_cond_encode_dim=2,
|
||||
micro_cond_embed_dim=10,
|
||||
encoder_hidden_size=32,
|
||||
vocab_size=32,
|
||||
codebook_size=32,
|
||||
in_channels=32,
|
||||
block_out_channels=32,
|
||||
num_res_blocks=1,
|
||||
downsample=True,
|
||||
upsample=True,
|
||||
block_num_heads=1,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
attention_dropout=0.0,
|
||||
intermediate_size=32,
|
||||
layer_norm_eps=1e-06,
|
||||
ln_elementwise_affine=True,
|
||||
)
|
||||
scheduler = AmusedScheduler(mask_token_id=31)
|
||||
torch.manual_seed(0)
|
||||
vqvae = VQModel(
|
||||
act_fn="silu",
|
||||
block_out_channels=[32],
|
||||
down_block_types=[
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
in_channels=3,
|
||||
latent_channels=32,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
num_vq_embeddings=32,
|
||||
out_channels=3,
|
||||
sample_size=32,
|
||||
up_block_types=[
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
mid_block_add_attention=False,
|
||||
lookup_from_codebook=True,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=3,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"scheduler": scheduler,
|
||||
"vqvae": vqvae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
|
||||
mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device)
|
||||
mask_image[0, 0, 0, 0] = 0
|
||||
mask_image[0, 0, 0, 1] = 0
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2]):
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
|
||||
|
||||
@unittest.skip("aMUSEd does not support lists of generators")
|
||||
def test_inference_batch_single_identical(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class AmusedInpaintPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
|
||||
.resize((256, 256))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
mask_image = (
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
|
||||
)
|
||||
.resize((256, 256))
|
||||
.convert("L")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
mask_image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634])
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
"huggingface/amused-256", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
|
||||
.resize((256, 256))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
mask_image = (
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
|
||||
)
|
||||
.resize((256, 256))
|
||||
.convert("L")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
mask_image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.0735, 0.0749, 0.0650, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622])
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
|
||||
.resize((512, 512))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
mask_image = (
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
|
||||
)
|
||||
.resize((512, 512))
|
||||
.convert("L")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
mask_image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0])
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.05
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
"huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
|
||||
.resize((512, 512))
|
||||
.convert("RGB")
|
||||
)
|
||||
|
||||
mask_image = (
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
|
||||
)
|
||||
.resize((512, 512))
|
||||
.convert("L")
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
"winter mountains",
|
||||
image,
|
||||
mask_image,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0025, 0.0])
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
@@ -692,58 +692,6 @@ class StableDiffusionPipelineFastTests(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 3
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -320,62 +320,6 @@ class StableDiffusionImg2ImgPipelineFastTests(
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=5e-1)
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 3
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -319,64 +319,6 @@ class StableDiffusionInpaintPipelineFastTests(
|
||||
out_1 = sd_pipe(**inputs).images
|
||||
assert np.abs(out_0 - out_1).max() < 1e-2
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionInpaintPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 3
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
mask_image=inputs["mask_image"],
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
mask_image=inputs["mask_image"],
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
|
||||
class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
|
||||
@@ -969,58 +969,6 @@ class StableDiffusionXLPipelineFastTests(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 3
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
@@ -439,64 +439,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
> 1e-4
|
||||
)
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 5
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
strength=0.8,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
strength=0.8,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
|
||||
|
||||
@@ -746,63 +746,3 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
image_slice1 = images[0, -3:, -3:, -1]
|
||||
image_slice2 = images[1, -3:, -3:, -1]
|
||||
assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2
|
||||
|
||||
def test_pipeline_interrupt(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
prompt = "hey"
|
||||
num_inference_steps = 5
|
||||
|
||||
# store intermediate latents from the generation process
|
||||
class PipelineState:
|
||||
def __init__(self):
|
||||
self.state = []
|
||||
|
||||
def apply(self, pipe, i, t, callback_kwargs):
|
||||
self.state.append(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
pipe_state = PipelineState()
|
||||
sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
mask_image=inputs["mask_image"],
|
||||
strength=0.8,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=pipe_state.apply,
|
||||
).images
|
||||
|
||||
# interrupt generation at step index
|
||||
interrupt_step_idx = 1
|
||||
|
||||
def callback_on_step_end(pipe, i, t, callback_kwargs):
|
||||
if i == interrupt_step_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
output_interrupted = sd_pipe(
|
||||
prompt,
|
||||
image=inputs["image"],
|
||||
mask_image=inputs["mask_image"],
|
||||
strength=0.8,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="latent",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
).images
|
||||
|
||||
# fetch intermediate latents at the interrupted step
|
||||
# from the completed generation process
|
||||
intermediate_latent = pipe_state.state[interrupt_step_idx]
|
||||
|
||||
# compare the intermediate latent to the output of the interrupted process
|
||||
# they should be the same
|
||||
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
|
||||
|
||||
@@ -437,7 +437,7 @@ class PipelineTesterMixin:
|
||||
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
|
||||
|
||||
def _test_inference_batch_consistent(
|
||||
self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
|
||||
self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -472,7 +472,7 @@ class PipelineTesterMixin:
|
||||
else:
|
||||
batched_input[name] = batch_size * [value]
|
||||
|
||||
if batch_generator and "generator" in inputs:
|
||||
if "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
|
||||
Reference in New Issue
Block a user