mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-09 05:54:24 +08:00
Compare commits
20 Commits
v0.26.3
...
rename-att
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd7f9c74e1 | ||
|
|
b71f35b908 | ||
|
|
493228a708 | ||
|
|
8bf046b7fb | ||
|
|
bb99623d09 | ||
|
|
fdf55b1f1c | ||
|
|
c6f8c310c3 | ||
|
|
64909f17b7 | ||
|
|
f09ca909c8 | ||
|
|
a5fc62f819 | ||
|
|
fbdf26bac5 | ||
|
|
13001ee315 | ||
|
|
65329aed98 | ||
|
|
02338c9317 | ||
|
|
15ed53d272 | ||
|
|
9cc59ba089 | ||
|
|
adcbe674a4 | ||
|
|
ec9840a5db | ||
|
|
093a03a1a1 | ||
|
|
c3369f5673 |
@@ -18,11 +18,11 @@ The abstract from the paper is:
|
||||
|
||||
*Video synthesis has recently made remarkable strides benefiting from the rapid development of diffusion models. However, it still encounters challenges in terms of semantic accuracy, clarity and spatio-temporal continuity. They primarily arise from the scarcity of well-aligned text-video data and the complex inherent structure of videos, making it difficult for the model to simultaneously ensure semantic and qualitative excellence. In this report, we propose a cascaded I2VGen-XL approach that enhances model performance by decoupling these two factors and ensures the alignment of the input data by utilizing static images as a form of crucial guidance. I2VGen-XL consists of two stages: i) the base stage guarantees coherent semantics and preserves content from input images by using two hierarchical encoders, and ii) the refinement stage enhances the video's details by incorporating an additional brief text and improves the resolution to 1280×720. To improve the diversity, we collect around 35 million single-shot text-video pairs and 6 billion text-image pairs to optimize the model. By this means, I2VGen-XL can simultaneously enhance the semantic accuracy, continuity of details and clarity of generated videos. Through extensive experiments, we have investigated the underlying principles of I2VGen-XL and compared it with current top methods, which can demonstrate its effectiveness on diverse data. The source code and models will be publicly available at [this https URL](https://i2vgen-xl.github.io/).*
|
||||
|
||||
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
|
||||
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -31,7 +31,7 @@ Sample output with I2VGenXL:
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
masterpiece, bestquality, sunset.
|
||||
library.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/i2vgen-xl-example.gif"
|
||||
alt="library"
|
||||
@@ -43,9 +43,9 @@ Sample output with I2VGenXL:
|
||||
## Notes
|
||||
|
||||
* I2VGenXL always uses a `clip_skip` value of 1. This means it leverages the penultimate layer representations from the text encoder of CLIP.
|
||||
* It can generate videos of quality that is often on par with [Stable Video Diffusion](../../using-diffusers/svd) (SVD).
|
||||
* Unlike SVD, it additionally accepts text prompts as inputs.
|
||||
* It can generate higher resolution videos.
|
||||
* It can generate videos of quality that is often on par with [Stable Video Diffusion](../../using-diffusers/svd) (SVD).
|
||||
* Unlike SVD, it additionally accepts text prompts as inputs.
|
||||
* It can generate higher resolution videos.
|
||||
* When using the [`DDIMScheduler`] (which is default for this pipeline), less than 50 steps for inference leads to bad results.
|
||||
|
||||
## I2VGenXLPipeline
|
||||
|
||||
@@ -70,7 +70,7 @@ Here are some sample outputs:
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
masterpiece, bestquality, sunset.
|
||||
cat in a field.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-default-output.gif"
|
||||
alt="cat in a field"
|
||||
@@ -119,7 +119,7 @@ image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
|
||||
)
|
||||
image = image.resize((512, 512))
|
||||
prompt = "cat in a hat"
|
||||
prompt = "cat in a field"
|
||||
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
@@ -132,7 +132,7 @@ export_to_gif(frames, "pia-freeinit-animation.gif")
|
||||
<table>
|
||||
<tr>
|
||||
<td><center>
|
||||
masterpiece, bestquality, sunset.
|
||||
cat in a field.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-freeinit-output-cat.gif"
|
||||
alt="cat in a field"
|
||||
|
||||
@@ -41,7 +41,7 @@ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", tor
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
video_frames = pipe(prompt).frames
|
||||
video_frames = pipe(prompt).frames[0]
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -64,7 +64,7 @@ pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
video_frames = pipe(prompt, num_frames=64).frames
|
||||
video_frames = pipe(prompt, num_frames=64).frames[0]
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -83,7 +83,7 @@ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
video_frames = pipe(prompt, num_inference_steps=25).frames
|
||||
video_frames = pipe(prompt, num_inference_steps=25).frames[0]
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -130,7 +130,7 @@ pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
video_frames = pipe(prompt, num_frames=24).frames
|
||||
video_frames = pipe(prompt, num_frames=24).frames[0]
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
@@ -148,7 +148,7 @@ pipe.enable_vae_slicing()
|
||||
|
||||
video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
|
||||
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames[0]
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
|
||||
244
examples/advanced_diffusion_training/README.md
Normal file
244
examples/advanced_diffusion_training/README.md
Normal file
@@ -0,0 +1,244 @@
|
||||
# Advanced diffusion training examples
|
||||
|
||||
## Train Dreambooth LoRA with Stable Diffusion XL
|
||||
> [!TIP]
|
||||
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
|
||||
|
||||
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
|
||||
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
|
||||
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with
|
||||
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
|
||||
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
|
||||
|
||||
> [!NOTE]
|
||||
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
|
||||
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
|
||||
|
||||
📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/advanced_diffusion_training` folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Pivotal Tuning
|
||||
**Training with text encoder(s)**
|
||||
|
||||
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
|
||||
available with `train_dreambooth_lora_sdxl_advanced.py`, in the advanced script **pivotal tuning** is also supported.
|
||||
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
|
||||
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
|
||||
We then optimize the newly-inserted token embeddings to represent the new concept.
|
||||
|
||||
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
|
||||
Please keep the following points in mind:
|
||||
|
||||
* SDXL has two text encoders. So, we fine-tune both using LoRA.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
|
||||
|
||||
|
||||
### 3D icon example
|
||||
|
||||
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./3d_icon"
|
||||
snapshot_download(
|
||||
"LinoyTsaban/3d_icon",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
Let's review some of the advanced features we're going to be using for this example:
|
||||
- **custom captions**:
|
||||
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
|
||||
```bash
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
|
||||
|
||||
```
|
||||
--dataset_name=./3d_icon
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
You can also load a dataset straight from by specifying it's name in `dataset_name`.
|
||||
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
|
||||
|
||||
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
|
||||
- **pivotal tuning**
|
||||
- **min SNR gamma**
|
||||
|
||||
**Now, we can launch training:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
|
||||
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE_PATH \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=3 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--snr_gamma=5.0 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
1. starting with loading the unet lora weights
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-SDXL-LoRA"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
|
||||
```
|
||||
2. now we load the pivotal tuning embeddings
|
||||
|
||||
```python
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
|
||||
3. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Comfy UI / AUTOMATIC1111 Inference
|
||||
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
|
||||
|
||||
**AUTOMATIC1111 / SD.Next** \
|
||||
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
|
||||
|
||||
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
|
||||
|
||||
**ComfyUI** \
|
||||
In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
|
||||
-
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
## Running on Colab Notebook
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb).
|
||||
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning)
|
||||
|
||||
7
examples/advanced_diffusion_training/requirements.txt
Normal file
7
examples/advanced_diffusion_training/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -119,10 +119,9 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
@@ -389,7 +388,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=1024,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
@@ -645,6 +644,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -745,10 +745,11 @@ class TokenEmbeddingsHandler:
|
||||
|
||||
idx += 1
|
||||
|
||||
# copied from train_dreambooth_lora_sdxl_advanced.py
|
||||
def save_embeddings(self, file_path: str):
|
||||
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
|
||||
tensors = {}
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
@@ -1634,6 +1635,11 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1788,6 +1794,7 @@ def main(args):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
@@ -1860,6 +1867,11 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -1895,6 +1907,18 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -1917,11 +1941,6 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
@@ -20,6 +20,7 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
@@ -45,6 +46,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
@@ -121,7 +123,7 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
@@ -397,18 +399,6 @@ def parse_args(input_args=None):
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_h",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_w",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -418,6 +408,11 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
@@ -659,6 +654,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -901,6 +897,41 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
# image processing to prepare for using SD-XL micro-conditioning
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
for image in self.instance_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
@@ -930,12 +961,12 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -966,6 +997,8 @@ class DreamBoothDataset(Dataset):
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
original_sizes = [example["original_size"] for example in examples]
|
||||
crop_top_lefts = [example["crop_top_left"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
@@ -976,7 +1009,12 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"prompts": prompts,
|
||||
"original_sizes": original_sizes,
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1198,7 +1236,9 @@ def main(args):
|
||||
args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
|
||||
if args.with_prior_preservation:
|
||||
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
|
||||
|
||||
if args.validation_prompt:
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
# initialize the new tokens for textual inversion
|
||||
embedding_handler = TokenEmbeddingsHandler(
|
||||
[text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
|
||||
@@ -1539,11 +1579,11 @@ def main(args):
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
def compute_time_ids():
|
||||
def compute_time_ids(crops_coords_top_left, original_size=None):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
original_size = (args.resolution, args.resolution)
|
||||
if original_size is None:
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1560,9 +1600,6 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
@@ -1573,7 +1610,6 @@ def main(args):
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
class_time_ids = compute_time_ids()
|
||||
if freeze_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
@@ -1588,9 +1624,6 @@ def main(args):
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
@@ -1613,12 +1646,6 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
if args.train_text_encoder_ti and args.validation_prompt:
|
||||
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
@@ -1778,6 +1805,12 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1789,19 +1822,26 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
[
|
||||
compute_time_ids(original_size=s, crops_coords_top_left=c)
|
||||
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if freeze_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"time_ids": add_time_ids,
|
||||
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -1812,7 +1852,7 @@ def main(args):
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
@@ -1954,6 +1994,8 @@ def main(args):
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
tokenizer_2=tokenizer_two,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
@@ -2033,6 +2075,11 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -2068,6 +2115,25 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -2090,11 +2156,6 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
@@ -104,6 +104,22 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -125,15 +141,8 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
@@ -233,6 +242,22 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -259,15 +284,8 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
||||
@@ -951,30 +969,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -1302,7 +1296,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
image_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated image embeddings.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1411,7 +1404,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if image_embeds is not None:
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
|
||||
image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
|
||||
device=device, dtype=prompt_embeds.dtype
|
||||
)
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
@@ -538,7 +538,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
|
||||
eps = 1e-6
|
||||
|
||||
output_states = ()
|
||||
@@ -634,7 +634,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
def hacked_UpBlock2D_forward(
|
||||
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
|
||||
):
|
||||
eps = 1e-6
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
# pop res hidden states
|
||||
|
||||
@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
|
||||
eps = 1e-6
|
||||
|
||||
output_states = ()
|
||||
@@ -603,7 +603,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
def hacked_UpBlock2D_forward(
|
||||
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
|
||||
):
|
||||
eps = 1e-6
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
# pop res hidden states
|
||||
|
||||
@@ -19,6 +19,7 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -40,6 +41,7 @@ from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
@@ -304,18 +306,6 @@ def parse_args(input_args=None):
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_h",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_w",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
@@ -669,6 +664,41 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
# image processing to prepare for using SD-XL micro-conditioning
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
for image in self.instance_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
@@ -698,12 +728,12 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -730,6 +760,8 @@ class DreamBoothDataset(Dataset):
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
original_sizes = [example["original_size"] for example in examples]
|
||||
crop_top_lefts = [example["crop_top_left"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"prompts": prompts,
|
||||
"original_sizes": original_sizes,
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1233,11 +1270,9 @@ def main(args):
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
def compute_time_ids():
|
||||
def compute_time_ids(original_size, crops_coords_top_left):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1254,9 +1289,6 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
@@ -1267,7 +1299,6 @@ def main(args):
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
class_time_ids = compute_time_ids()
|
||||
if not args.train_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
@@ -1282,9 +1313,6 @@ def main(args):
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
@@ -1399,8 +1427,8 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
text_encoder_one.text_model.embeddings.requires_grad_(True)
|
||||
text_encoder_two.text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
@@ -1436,18 +1464,24 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
[
|
||||
compute_time_ids(original_size=s, crops_coords_top_left=c)
|
||||
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if not args.train_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"time_ids": add_time_ids,
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -1459,7 +1493,7 @@ def main(args):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
|
||||
@@ -158,6 +158,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
|
||||
@@ -292,7 +292,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
|
||||
output = [
|
||||
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
|
||||
]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...utils import is_torch_version
|
||||
from ...utils import deprecate, is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..resnet import (
|
||||
@@ -44,7 +44,8 @@ def get_down_block(
|
||||
add_downsample: bool,
|
||||
resnet_eps: float,
|
||||
resnet_act_fn: str,
|
||||
num_attention_heads: int,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
resnet_groups: Optional[int] = None,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
downsample_padding: Optional[int] = None,
|
||||
@@ -80,6 +81,16 @@ def get_down_block(
|
||||
elif down_block_type == "CrossAttnDownBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
||||
if num_attention_heads is not None:
|
||||
deprecation_message = (
|
||||
" passing `num`_attention_heads` to `unet_3d_blocks.get_down_block` for CrossAttnDownBlock3D is deprecated. "
|
||||
" Please use `attention_head_dim` instead."
|
||||
)
|
||||
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = num_attention_heads
|
||||
if attention_head_dim is None:
|
||||
raise ValueError("`attention_head_dim` must be specified for CrossAttnDownBlock3D")
|
||||
return CrossAttnDownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
@@ -91,7 +102,8 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_attention_heads=None,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
@@ -173,7 +185,8 @@ def get_up_block(
|
||||
add_upsample: bool,
|
||||
resnet_eps: float,
|
||||
resnet_act_fn: str,
|
||||
num_attention_heads: int,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
resolution_idx: Optional[int] = None,
|
||||
resnet_groups: Optional[int] = None,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
@@ -212,6 +225,16 @@ def get_up_block(
|
||||
elif up_block_type == "CrossAttnUpBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
||||
if num_attention_heads is not None:
|
||||
deprecation_message = (
|
||||
" passing `num`_attention_heads` to `unet_3d_blocks.get_up_block` for CrossAttnUpBlock3D is deprecated. "
|
||||
" Please use `attention_head_dim` instead."
|
||||
)
|
||||
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = num_attention_heads
|
||||
if attention_head_dim is None:
|
||||
raise ValueError("`attention_head_dim` must be specified for CrossAttnUpBlock3D")
|
||||
return CrossAttnUpBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
@@ -223,7 +246,8 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_attention_heads=None,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
@@ -314,7 +338,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
output_scale_factor: float = 1.0,
|
||||
cross_attention_dim: int = 1280,
|
||||
dual_cross_attention: bool = False,
|
||||
@@ -322,9 +347,19 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
upcast_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if num_attention_heads is not None:
|
||||
deprecation_message = (
|
||||
" passing `num`_attention_heads` to `unet_3d_blocks.UNetMidBlock3DCrossAttn` is deprecated. "
|
||||
" Please use `attention_head_dim` instead."
|
||||
)
|
||||
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = num_attention_heads
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if attention_head_dim is None:
|
||||
raise ValueError("`attention_head_dim` must be specified for UNetMidBlock3DCrossAttn")
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
@@ -356,8 +391,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
in_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
in_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -368,8 +403,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
in_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
in_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -449,7 +484,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
downsample_padding: int = 1,
|
||||
@@ -460,13 +496,23 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
upcast_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if num_attention_heads is not None:
|
||||
deprecation_message = (
|
||||
" passing `num`_attention_heads` to `unet_3d_blocks.CrossAttnDownBlock3D` is deprecated. "
|
||||
" Please use `attention_head_dim` instead."
|
||||
)
|
||||
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = num_attention_heads
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if attention_head_dim is None:
|
||||
raise ValueError("`attention_head_dim` must be specified for CrossAttnDownBlock3D")
|
||||
resnets = []
|
||||
attentions = []
|
||||
temp_attentions = []
|
||||
temp_convs = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
@@ -494,8 +540,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
out_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
out_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -507,8 +553,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
out_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
out_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -681,7 +727,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
@@ -692,13 +739,25 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
resolution_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if num_attention_heads is not None:
|
||||
deprecation_message = (
|
||||
" passing `num`_attention_heads` to `unet_3d_blocks.CrossAttnUpBlock3D` is deprecated. "
|
||||
" Please use `attention_head_dim` instead."
|
||||
)
|
||||
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = num_attention_heads
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
if attention_head_dim is None:
|
||||
raise ValueError("`attention_head_dim` must be specified for CrossAttnUpBlock3D")
|
||||
|
||||
resnets = []
|
||||
temp_convs = []
|
||||
attentions = []
|
||||
temp_attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
@@ -728,8 +787,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
out_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
out_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -741,8 +800,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
)
|
||||
temp_attentions.append(
|
||||
TransformerTemporalModel(
|
||||
out_channels // num_attention_heads,
|
||||
num_attention_heads,
|
||||
out_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -1031,16 +1090,10 @@ class DownBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, scale
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states.requires_grad_(),
|
||||
temb,
|
||||
num_frames,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -1221,10 +1274,10 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
@@ -1425,10 +1478,10 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1563,15 +1616,10 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
||||
@@ -132,14 +132,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
@@ -151,9 +143,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
@@ -187,8 +179,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
@@ -208,7 +200,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
@@ -222,7 +214,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
num_attention_heads=None,
|
||||
attention_head_dim=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
@@ -232,7 +225,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
@@ -261,7 +254,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
attention_head_dim=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=False,
|
||||
resolution_idx=i,
|
||||
)
|
||||
|
||||
@@ -792,6 +792,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
@@ -799,10 +800,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
image_embeds = self.encoder_hid_proj(image_embeds)
|
||||
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
|
||||
# 2. pre-process
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
||||
|
||||
@@ -11,12 +11,13 @@ from ...utils import BaseOutput
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for AnimateDiff pipelines.
|
||||
Output class for AnimateDiff pipelines.
|
||||
|
||||
Args:
|
||||
frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`):
|
||||
List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape
|
||||
`(batch_size, num_frames, height, width, num_channels)`.
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
"""
|
||||
|
||||
frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray]
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
|
||||
@@ -789,6 +789,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -705,6 +705,8 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -871,6 +871,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -566,6 +566,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -536,6 +536,8 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import I2VGenXLPipeline
|
||||
>>> from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
>>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
|
||||
>>> pipeline.enable_model_cpu_offload()
|
||||
@@ -95,15 +96,16 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
|
||||
@dataclass
|
||||
class I2VGenXLPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for image-to-video pipeline.
|
||||
Output class for image-to-video pipeline.
|
||||
|
||||
Args:
|
||||
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
||||
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
||||
a `torch` tensor. The length of the list denotes the video length (the number of frames).
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
"""
|
||||
|
||||
frames: Union[List[np.ndarray], torch.FloatTensor]
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
|
||||
|
||||
class I2VGenXLPipeline(DiffusionPipeline):
|
||||
|
||||
@@ -634,6 +634,8 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.fft as fft
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -200,16 +200,18 @@ class PIAPipelineOutput(BaseOutput):
|
||||
Output class for PIAPipeline.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[PIL.Image.Image]):
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
|
||||
NumPy array of shape `(batch_size, num_frames, channels, height, width,
|
||||
Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, PIL.Image.Image]
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
|
||||
|
||||
class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
class PIAPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
@@ -685,6 +687,35 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
@@ -906,6 +937,8 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@@ -1105,12 +1138,9 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_videos_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -467,6 +467,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -659,6 +659,8 @@ class StableDiffusionImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -859,6 +859,8 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -754,6 +754,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...utils import (
|
||||
@@ -12,12 +13,13 @@ from ...utils import (
|
||||
@dataclass
|
||||
class TextToVideoSDPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for text-to-video pipelines.
|
||||
Output class for text-to-video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
||||
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
||||
a `torch` tensor. The length of the list denotes the video length (the number of frames).
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
|
||||
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`
|
||||
"""
|
||||
|
||||
frames: Union[List[np.ndarray], torch.FloatTensor]
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
||||
|
||||
@@ -554,6 +554,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
return indices.item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -196,6 +197,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -255,6 +274,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -620,11 +640,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -637,7 +658,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -736,16 +770,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -227,6 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -236,6 +237,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -311,6 +329,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -792,11 +811,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -809,7 +828,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
return step_index
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -920,16 +951,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -767,7 +767,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
@@ -879,7 +878,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -198,9 +197,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
@@ -211,31 +211,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
self._step_index = self._begin_index
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -252,6 +239,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -348,13 +353,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
def _second_order_timesteps(self, sigmas, log_sigmas):
|
||||
def sigma_fn(_t):
|
||||
return np.exp(-_t)
|
||||
@@ -444,10 +446,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
# Create a noise sampler if it hasn't been created yet
|
||||
if self.noise_sampler is None:
|
||||
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
|
||||
@@ -527,7 +525,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -544,7 +542,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -210,6 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
@@ -253,6 +254,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -315,6 +334,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -813,11 +833,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -830,7 +851,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -925,16 +959,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -216,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -233,6 +234,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -300,25 +319,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -440,7 +466,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -255,6 +256,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -342,6 +361,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -393,22 +413,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -538,7 +563,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# (YiYi notes: keep this for now since we are keeping the add_noise method)
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# running values
|
||||
self.ets = []
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.ets = []
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# set all values
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_interpol = self.sigmas_interpol[self.step_index]
|
||||
@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(self._index_counter) == 0:
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
else:
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
pos = self._index_counter[timestep_int]
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sample = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_interpol = self.sigmas_interpol[self.step_index + 1]
|
||||
@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
||||
self.sigma_data = 0.5 # Default: 0.5
|
||||
|
||||
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
self.derivatives = []
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
self._step_index = step_index.item()
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.lower_order_nums = 0
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
|
||||
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
return step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = []
|
||||
for timestep in timesteps:
|
||||
index_candidates = (schedule_timesteps == timestep).nonzero()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(schedule_timesteps) - 1
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
step_indices.append(step_index)
|
||||
# begin_index is None when the scheduler is used for training
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from .models import UNet2DConditionModel
|
||||
from .utils import (
|
||||
@@ -13,6 +12,7 @@ from .utils import (
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
is_peft_available,
|
||||
is_torchvision_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,9 @@ if is_transformers_available():
|
||||
if is_peft_available():
|
||||
from peft import set_peft_model_state_dict
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
|
||||
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
|
||||
transform.
|
||||
"""
|
||||
if not is_torchvision_available():
|
||||
raise ImportError(
|
||||
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
|
||||
)
|
||||
|
||||
if interpolation_type == "bilinear":
|
||||
interpolation_mode = transforms.InterpolationMode.BILINEAR
|
||||
elif interpolation_type == "bicubic":
|
||||
|
||||
@@ -75,6 +75,7 @@ from .import_utils import (
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
is_torchsde_available,
|
||||
is_torchvision_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
is_unidecode_available,
|
||||
|
||||
@@ -278,6 +278,13 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
_torchvision_available = importlib.util.find_spec("torchvision") is not None
|
||||
try:
|
||||
_torchvision_version = importlib_metadata.version("torchvision")
|
||||
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchvision_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
@@ -367,6 +374,10 @@ def is_peft_available():
|
||||
return _peft_available
|
||||
|
||||
|
||||
def is_torchvision_available():
|
||||
return _torchvision_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -854,6 +854,8 @@ def _is_torch_fp64_available(device):
|
||||
|
||||
import torch
|
||||
|
||||
device = torch.device(device)
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
|
||||
0
tests/models/autoencoders/__init__.py
Normal file
0
tests/models/autoencoders/__init__.py
Normal file
@@ -46,7 +46,7 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -4,7 +4,7 @@ from diffusers import FlaxAutoencoderKL
|
||||
from diffusers.utils import is_flax_available
|
||||
from diffusers.utils.testing_utils import require_flax
|
||||
|
||||
from .test_modeling_common_flax import FlaxModelTesterMixin
|
||||
from ..test_modeling_common_flax import FlaxModelTesterMixin
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
0
tests/models/unets/__init__.py
Normal file
0
tests/models/unets/__init__.py
Normal file
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -23,7 +23,7 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -28,7 +28,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -119,7 +120,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
expected_slice = np.array([0.80810547, 0.88183594, 0.9296875, 0.9189453, 0.9848633, 1.0, 0.97021484, 1.0, 1.0])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -131,7 +133,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.30444336, 0.26513672, 0.22436523, 0.2758789, 0.25585938, 0.20751953, 0.25390625, 0.24633789, 0.21923828]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_image_to_image(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -149,7 +152,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.22167969, 0.21875, 0.21728516, 0.22607422, 0.21948242, 0.23925781, 0.22387695, 0.25268555, 0.2722168]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -161,7 +165,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.35913086, 0.265625, 0.26367188, 0.24658203, 0.19750977, 0.39990234, 0.15258789, 0.20336914, 0.5517578]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_inpainting(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -179,7 +184,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.27148438, 0.24047852, 0.22167969, 0.23217773, 0.21118164, 0.21142578, 0.21875, 0.20751953, 0.20019531]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -187,11 +193,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.27294922, 0.24023438, 0.21948242, 0.23242188, 0.20825195, 0.2055664, 0.21679688, 0.20336914, 0.19360352]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_text_to_image_model_cpu_offload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -233,11 +236,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.18115234, 0.13500977, 0.13427734, 0.24194336, 0.17138672, 0.16625977, 0.4260254, 0.43359375, 0.4416504]
|
||||
)
|
||||
expected_slice = np.array([0.1958, 0.1475, 0.1396, 0.2412, 0.1658, 0.1533, 0.3997, 0.4055, 0.4128])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_unload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -277,7 +279,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
expected_slice = np.array(
|
||||
[0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
|
||||
)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
|
||||
@slow
|
||||
@@ -314,7 +318,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
|
||||
@@ -339,7 +344,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.0576596, 0.05600825, 0.04479006, 0.05288461, 0.05461192, 0.05137569, 0.04867965, 0.05301541, 0.04939842]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
def test_image_to_image_sdxl(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
|
||||
@@ -432,7 +438,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
[0.14181179, 0.1493012, 0.14283323, 0.14602411, 0.14915377, 0.15015268, 0.14725655, 0.15009224, 0.15164584]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
@@ -457,4 +464,5 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 5e-4
|
||||
|
||||
@@ -38,7 +38,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..models.test_models_vae import (
|
||||
from ..models.autoencoders.test_models_vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
|
||||
Reference in New Issue
Block a user