Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 06:24:19 +08:00)

Compare commits: 35 commits (animatedif ... add-peft-t)
| SHA1 |
|---|
| 0f9427f0dd |
| f6844d3cf6 |
| daa75665cf |
| 43672b4a22 |
| 9df3d84382 |
| c751449011 |
| c1e8bdf1d4 |
| 78b87dc25a |
| 0af12f1f8a |
| 6e123688dc |
| f0a588b8e2 |
| fa31704420 |
| 9d79991da0 |
| 7d865ac9c6 |
| fb02316db8 |
| 98a2b3d2d8 |
| 2026ec0a02 |
| 3706aa3305 |
| d4f10ea362 |
| 3aba99af8f |
| 6683f97959 |
| 4e7b0cb396 |
| 35b81fffae |
| e0d8c910e9 |
| 38aece94c4 |
| b03aa10375 |
| 2bfdcabadc |
| a837033105 |
| 566aaab423 |
| 9b910bdc5c |
| 965b40aa17 |
| d62076ac5f |
| 3338ce0d40 |
| 565416c2e7 |
| 6338ad5b0b |
@@ -19,6 +19,8 @@
    title: Train a diffusion model
  - local: tutorials/using_peft_for_inference
    title: Inference with PEFT
  - local: tutorials/fast_diffusion
    title: Accelerate inference of text-to-image diffusion models
  title: Tutorials
- sections:
  - sections:
@@ -264,10 +266,6 @@
    title: ControlNet
  - local: api/pipelines/controlnet_sdxl
    title: ControlNet with Stable Diffusion XL
  - local: api/pipelines/controlnetxs
    title: ControlNet-XS
  - local: api/pipelines/controlnetxs_sdxl
    title: ControlNet-XS with Stable Diffusion XL
  - local: api/pipelines/dance_diffusion
    title: Dance Diffusion
  - local: api/pipelines/ddim
@@ -18,12 +18,24 @@ Amused is a vqvae token based transformer that can generate an image in fewer fo

| Model | Params |
|-------|--------|
| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M |
| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M |
| [amused-256](https://huggingface.co/amused/amused-256) | 603M |
| [amused-512](https://huggingface.co/amused/amused-512) | 608M |

## AmusedPipeline

[[autodoc]] AmusedPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

[[autodoc]] AmusedImg2ImgPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

[[autodoc]] AmusedInpaintPipeline
- __call__
- all
- enable_xformers_memory_efficient_attention
docs/source/en/tutorials/fast_diffusion.md (new file, 318 lines)
@@ -0,0 +1,318 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Accelerate inference of text-to-image diffusion models

Diffusion models are known to be slower than their counterparts, GANs, because of the iterative and sequential reverse diffusion process. Recent works try to address this limitation with:

* progressive timestep distillation (such as [LCM LoRA](../using-diffusers/inference_with_lcm_lora.md))
* model compression (such as [SSD-1B](https://huggingface.co/segmind/SSD-1B))
* reusing adjacent features of the denoiser (such as [DeepCache](https://github.com/horseee/DeepCache))

In this tutorial, we focus instead on leveraging the power of PyTorch 2 to reduce the inference latency of a text-to-image diffusion pipeline. We will use [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl.md) as a case study, but the techniques we discuss should extend to other text-to-image diffusion pipelines.

## Setup

Make sure you're on the latest version of `diffusers`:

```bash
pip install -U diffusers
```

Then upgrade the other required libraries too:

```bash
pip install -U transformers accelerate peft
```

To benefit from the fastest kernels, use PyTorch nightly. You can find the installation instructions [here](https://pytorch.org/).

To report the numbers shown below, we used an 80GB 400W A100 with its clock rate set to the maximum.

_This tutorial doesn't present the benchmarking code; it focuses on how to perform the optimizations. For the full benchmarking code, refer to: [https://github.com/huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast)._
## Baseline

Let's start with a baseline. Disable reduced precision and [`scaled_dot_product_attention`](../optimization/torch2.0.md):

```python
from diffusers import StableDiffusionXLPipeline

# Load the pipeline in full precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
).to("cuda")

# Run the attention ops without the efficient SDPA implementation.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
```

This takes 7.36 seconds:

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_0.png" width=500>

</div>
## Running inference in bfloat16

Enable the first optimization: running inference in a reduced precision.

```python
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Run the attention ops without the efficient SDPA implementation.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
```

bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_1.png" width=500>

</div>

**Why bfloat16?**

* Using a reduced numerical precision (such as float16 or bfloat16) to run inference doesn't noticeably affect the generation quality but significantly improves latency.
* The benefits of using bfloat16 as compared to float16 are hardware-dependent. Modern generations of GPUs tend to favor bfloat16.
* Furthermore, in our experiments, we found bfloat16 to be much more resilient than float16 when used with quantization.

We have a [dedicated guide](../optimization/fp16.md) for running inference in reduced precision.
## Running attention efficiently

Attention blocks are compute-intensive. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0.md), we can run them efficiently.

```python
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
```

`scaled_dot_product_attention` improves the latency from 4.63 seconds to 3.31 seconds.

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_2.png" width=500>

</div>
## Use faster kernels with torch.compile

Compile the UNet and the VAE to benefit from the faster kernels. First, configure a few compiler flags:

```python
from diffusers import StableDiffusionXLPipeline
import torch

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
```

For the full list of compiler flags, refer to [this file](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py).

It is also important to change the memory layout of the UNet and the VAE to "channels_last" when compiling them. This ensures maximum speed:

```python
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
```

Then, compile and perform inference:

```python
# Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# The first call to `pipe` will be slow; subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]
```

`torch.compile` offers different backends and modes. As we're aiming for maximum inference speed, we opt for the inductor backend with the "max-autotune" mode, which uses CUDA graphs and optimizes the compilation graph specifically for latency. Setting `fullgraph=True` ensures there are no graph breaks in the underlying model, so `torch.compile` can be used to its fullest potential.

Using SDPA attention and compiling both the UNet and VAE reduces the latency from 3.31 seconds to 2.54 seconds.

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>

</div>
## Combine the projection matrices of attention

Both the UNet and the VAE used in SDXL make use of Transformer-like blocks. A Transformer block consists of attention blocks and feed-forward blocks.

In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. In the naive implementation, these projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one shot. This increases the size of the matmuls of the input projections and improves the impact of quantization (to be discussed next).
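To make the idea concrete, here is a minimal, self-contained PyTorch sketch of the principle. The layer sizes are made up for illustration, and this is not the actual Diffusers implementation, which is applied with the single call shown next:

```python
import torch
import torch.nn as nn

hidden_dim = 640
x = torch.randn(1, 77, hidden_dim)

# Naive: three separate projections, i.e. three matmuls over the same input.
to_q, to_k, to_v = (nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(3))
q, k, v = to_q(x), to_k(x), to_v(x)

# Fused: concatenate the three weight matrices into one, run a single larger
# matmul, then split the result back into Q, K, and V.
to_qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
q_fused, k_fused, v_fused = to_qkv(x).chunk(3, dim=-1)

assert torch.allclose(q, q_fused, atol=1e-6) and torch.allclose(k, k_fused, atol=1e-6)
```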
Enabling this kind of computation in Diffusers just takes a single line of code:

```python
pipe.fuse_qkv_projections()
```

It provides a minor boost from 2.54 seconds to 2.52 seconds.

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_4.png" width=500>

</div>

<Tip warning={true}>

Support for `fuse_qkv_projections()` is limited and experimental. As such, it's not available for many non-SD pipelines such as [Kandinsky](../using-diffusers/kandinsky.md). You can refer to [this PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to support this kind of computation.

</Tip>
## Dynamic quantization

Finally, apply [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to both the UNet and the VAE. Quantization adds conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
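As a rough illustration of what dynamic quantization means before wiring it into the pipeline, here is a small sketch that uses stock PyTorch's `quantize_dynamic` on CPU (the tutorial itself relies on the torchao APIs shown below): weights are stored in int8 and activations are quantized on the fly right before the int8 matmul.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))

# Replace fp32 linear layers with dynamically quantized int8 versions.
quantized = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 1024)
with torch.no_grad():
    out = quantized(x)  # weights are int8, activations quantized at runtime
print(out.shape)  # torch.Size([4, 1024])
```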
<Tip>

Through experimentation, we found that certain linear layers in the UNet and the VAE don't benefit from dynamic int8 quantization. You can check out the full code for filtering those layers [here](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) (referred to as `dynamic_quant_filter_fn` below).

</Tip>

You will leverage the ultra-lightweight pure PyTorch library [torchao](https://github.com/pytorch-labs/ao) to use its user-friendly APIs for quantization.

First, configure all the compiler flags:

```python
from diffusers import StableDiffusionXLPipeline
import torch

# Notice the two new flags at the end.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
```
Define the filtering functions:

```python
def dynamic_quant_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )


def conv_filter_fn(mod, *args):
    return (
        isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels]
    )
```
Then apply all the optimizations discussed so far:

```python
# SDPA + bfloat16.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Combine attention projection matrices.
pipe.fuse_qkv_projections()

# Change the memory layout.
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
```

Since this quantization support is limited to linear layers only, we also turn suitable pointwise (1x1) convolution layers into linear layers to maximize the benefit.

```python
from torchao import swap_conv2d_1x1_to_linear

swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)
```
Apply dynamic quantization:

```python
from torchao import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
```

Finally, compile and perform inference:

```python
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
```

Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 seconds.

<div align="center">

<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>

</div>
@@ -183,3 +183,26 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
# Gets the Unet back to the original state
pipe.unfuse_lora()
```

You can also fuse some adapters using `adapter_names` for faster generation:

```py
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")

pipe.set_adapters(["pixel"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora(adapter_names=["pixel"])
|
||||
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
|
||||
# Fuse all adapters
|
||||
pipe.fuse_lora(adapter_names=["pixel", "toy"])
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!

</Tip>

## Using Callbacks to interrupt the Diffusion Process

The following pipelines support interrupting the diffusion process via callback:

- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
- [StableDiffusionImg2ImgPipeline](../api/pipelines/stable_diffusion/img2img.md)
- [StableDiffusionInpaintPipeline](../api/pipelines/stable_diffusion/inpaint.md)
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
```
@@ -44,7 +44,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
@@ -58,6 +58,11 @@ export_to_video(frames, "generated.mp4", fps=7)
  <source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
</video>

| **Source Image** | **Video** |
|:------------:|:-----:|
|  |  |

<Tip>
Since generating videos is more memory intensive, we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory, but the video might have some flickering.
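For reference, a minimal sketch of passing `decode_chunk_size` at call time (it assumes the same checkpoint and conditioning image as in the guide above; lower values trade decoding speed for memory):

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

# Decode only two frames at a time to keep the VAE's memory usage low.
frames = pipe(image, decode_chunk_size=2, generator=torch.manual_seed(42)).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```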
@@ -120,7 +125,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
@@ -128,7 +133,5 @@ frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=
export_to_video(frames, "generated.mp4", fps=7)
```

<video width="1024" height="576" controls>
  <source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4" type="video/mp4">
</video>

@@ -37,6 +37,8 @@ from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
@@ -54,10 +56,9 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -67,39 +68,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(
repo_id: str,
images=None,
@@ -1262,54 +1230,25 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()

# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)

# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# if we use textual inversion, we freeze all parameters except for the token embeddings
# in text encoder
@@ -1333,6 +1272,17 @@ def main(args):
else:
param.requires_grad = False

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1344,11 +1294,15 @@ def main(args):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_lora_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1405,6 +1359,12 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))

# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
@@ -1995,13 +1955,17 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_lora_state_dict(unet)
unet_lora_layers = get_peft_model_state_dict(unet)

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -29,7 +29,7 @@ accelerate launch train_amused.py \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--pretrained_model_name_or_path huggingface/amused-256 \
--pretrained_model_name_or_path amused/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
@@ -70,7 +70,7 @@ accelerate launch train_amused.py \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 2e-5 \
--use_8bit_adam \
--pretrained_model_name_or_path huggingface/amused-256 \
--pretrained_model_name_or_path amused/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
@@ -109,7 +109,7 @@ accelerate launch train_amused.py \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--pretrained_model_name_or_path amused/amused-256 \
--instance_data_dataset 'm1guelpf/nouns' \
--image_key image \
--prompt_key text \
@@ -155,7 +155,7 @@ accelerate launch train_amused.py \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 8e-5 \
--pretrained_model_name_or_path huggingface/amused-512 \
--pretrained_model_name_or_path amused/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
@@ -191,7 +191,7 @@ accelerate launch train_amused.py \
--train_batch_size <batch size> \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 5e-6 \
--pretrained_model_name_or_path huggingface/amused-512 \
--pretrained_model_name_or_path amused/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
@@ -228,7 +228,7 @@ accelerate launch train_amused.py \
--gradient_accumulation_steps <gradient accumulation steps> \
--learning_rate 1e-4 \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--pretrained_model_name_or_path amused/amused-512 \
--instance_data_dataset 'monadical-labs/minecraft-preview' \
--prompt_prefix 'minecraft ' \
--image_key image \
@@ -276,7 +276,7 @@ accelerate launch train_amused.py \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-256 \
--pretrained_model_name_or_path amused/amused-256 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 4e-4 \
@@ -308,7 +308,7 @@ accelerate launch train_amused.py \
--mixed_precision fp16 \
--report_to wandb \
--use_lora \
--pretrained_model_name_or_path huggingface/amused-512 \
--pretrained_model_name_or_path amused/amused-512 \
--train_batch_size 1 \
--lr_scheduler constant \
--learning_rate 1e-3 \
@@ -50,6 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging

@@ -608,7 +609,7 @@ class TorchVAEEncoder(torch.nn.Module):
self.vae_encoder = model

def forward(self, x):
return self.vae_encoder.encode(x).latent_dist.sample()
return retrieve_latents(self.vae_encoder.encode(x))


class VAEEncoder(BaseModel):
@@ -111,4 +111,38 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--report_to=wandb \
--seed=453645634 \
--push_to_hub \
```
```

We provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.

Below is an example training command that trains an LCM LoRA on the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_lcm_distill_lora_sdxl.py \
--pretrained_teacher_model=${MODEL_NAME} \
--pretrained_vae_model_name_or_path=${VAE_PATH} \
--output_dir="pokemons-lora-lcm-sdxl" \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=1024 \
--train_batch_size=24 \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--lora_rank=64 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=3000 \
--checkpointing_steps=500 \
--validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
examples/consistency_distillation/test_lcm_lora.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class TextToImageLCM(ExamplesTestsAccelerate):
    def test_text_to_image_lcm_lora_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

    def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --checkpointing_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
            )

            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 9
                --checkpointing_steps 2
                --resume_from_checkpoint latest
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
            )
@@ -38,7 +38,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand
from huggingface_hub import create_repo
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from torch.utils.data import default_collate
@@ -847,7 +847,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
create_repo(
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
@@ -1366,6 +1366,14 @@ def main(args):
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py (new file, 1358 lines)
File diff suppressed because it is too large.
@@ -39,7 +39,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand
from huggingface_hub import create_repo
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from torch.utils.data import default_collate
@@ -842,7 +842,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
create_repo(
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
@@ -1424,6 +1424,14 @@ def main(args):
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()
@@ -38,7 +38,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand
from huggingface_hub import create_repo
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torch.utils.data import default_collate
from torchvision import transforms
@@ -835,7 +835,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
create_repo(
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
@@ -889,6 +889,7 @@ def main(args):
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1175,7 +1176,7 @@ def main(args):

# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -1353,6 +1354,14 @@ def main(args):
target_unet = accelerator.unwrap_model(target_unet)
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()
@@ -39,7 +39,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from braceexpand import braceexpand
from huggingface_hub import create_repo
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torch.utils.data import default_collate
from torchvision import transforms
@@ -875,7 +875,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True)

if args.push_to_hub:
create_repo(
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
@@ -948,6 +948,7 @@ def main(args):
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1273,7 +1274,7 @@ def main(args):

# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -1456,6 +1457,14 @@ def main(args):
target_unet = accelerator.unwrap_model(target_unet)
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))

if args.push_to_hub:
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)

accelerator.end_training()
@@ -54,7 +54,7 @@ from diffusers import (
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -1019,11 +1019,15 @@ def main(args):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1615,13 +1619,17 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -71,7 +71,7 @@ accelerate launch train_instruct_pix2pix_sdxl.py \

We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.

[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
[Here](https://wandb.ai/sayakpaul/instruct-pix2pix-sdxl-new/runs/sw53gxmc), you can find an example training run that includes some validation samples and the training hyperparameters.

***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
@@ -1,15 +1,3 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ControlNet-XS

ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -24,16 +12,5 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe

This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
- all
- __call__

## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
@@ -1,15 +1,3 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ControlNet-XS with Stable Diffusion XL

ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -24,22 +12,4 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe

This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️

<Tip warning={true}>

🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!

</Tip>

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
- all
- __call__

## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
@@ -21,13 +21,12 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.normalization import GroupNorm
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import USE_PEFT_BACKEND, AttentionProcessor
|
||||
from .autoencoders import AutoencoderKL
|
||||
from .lora import LoRACompatibleConv
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
@@ -37,7 +36,8 @@ from .unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -0,0 +1,58 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sd.canny.png")
|
||||
@@ -0,0 +1,57 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sdxl.canny.png")
|
||||
@@ -19,74 +19,30 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5
|
||||
>>> controlnet = ControlNetXSModel.from_pretrained(
|
||||
... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
@@ -669,7 +625,6 @@ class StableDiffusionControlNetXSPipeline(
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -21,76 +21,36 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSModel, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
||||
>>> controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
|
||||
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
@@ -730,7 +690,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -527,9 +527,17 @@ def main():
|
||||
|
||||
# lora attn processor
|
||||
prior_lora_config = LoraConfig(
|
||||
r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
prior.add_adapter(prior_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in prior.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
|
||||
scripts/convert_animatediff_motion_lora_to_diffusers.py (new file, 51 lines)
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
|
||||
# convert to new format
|
||||
output_dict = {}
|
||||
for module_name, params in conv_state_dict.items():
|
||||
if type(params) is not torch.Tensor:
|
||||
continue
|
||||
output_dict.update({f"unet.{module_name}": params})
|
||||
|
||||
save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
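
The converted folder can then be loaded as a motion LoRA. Below is a minimal sketch (not part of this diff): the base model and motion-adapter checkpoint ids are reused from the AnimateDiff docstring further down, the local output directory is an assumption, and it assumes `AnimateDiffPipeline` exposes `load_lora_weights()` through its LoRA loader mixin.

```py
# Minimal sketch: load a motion LoRA converted by the script above.
# The model ids and the local path are assumptions, not taken from this diff.
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)

# `--output_path` of convert_animatediff_motion_lora_to_diffusers.py
pipe.load_lora_weights("./converted-motion-lora", weight_name="diffusion_pytorch_model.safetensors")
output = pipe(prompt="A corgi walking in the park")
```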
scripts/convert_animatediff_motion_module_to_diffusers.py (new file, 51 lines)
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--use_motion_mid_block", action="store_true")
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
adapter = MotionAdapter(
|
||||
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
|
||||
)
|
||||
# skip loading position embeddings
|
||||
adapter.load_state_dict(conv_state_dict, strict=False)
|
||||
adapter.save_pretrained(args.output_path)
|
||||
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
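
For reference, the saved adapter can be loaded back with `MotionAdapter.from_pretrained()`. A minimal sketch (not part of this diff), reusing the checkpoint id and scheduler settings from the AnimateDiff docstring further down; the local path is an assumption.

```py
# Minimal sketch: load the MotionAdapter saved by the script above.
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("./converted-motion-module")  # `--output_path` used above
pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)

output = pipe(prompt="A corgi walking in the park")
export_to_gif(output.frames[0], "corgi.gif")
```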
@@ -80,7 +80,6 @@ else:
|
||||
"AutoencoderTiny",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetXSModel",
|
||||
"Kandinsky3UNet",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
@@ -256,7 +255,6 @@ else:
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
@@ -280,7 +278,6 @@ else:
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
@@ -462,7 +459,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetXSModel,
|
||||
Kandinsky3UNet,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
@@ -617,7 +613,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
@@ -641,7 +636,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
||||
@@ -166,6 +166,244 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
||||
"""
|
||||
Blurs an image.
|
||||
"""
|
||||
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
||||
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
||||
"""
|
||||
Finds a rectangular region that contains all masked areas in an image, and expands the region to match the aspect ratio of the original image;
for example, if the user drew a mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.

Args:
mask_image (PIL.Image.Image): Mask image.
width (int): Width of the image to be processed.
height (int): Height of the image to be processed.
pad (int, optional): Padding to be added to the crop region. Defaults to 0.

Returns:
tuple: (x1, y1, x2, y2) represents a rectangular region that contains all masked areas in an image and matches the original aspect ratio.
|
||||
"""
|
||||
|
||||
mask_image = mask_image.convert("L")
|
||||
mask = np.array(mask_image)
|
||||
|
||||
# 1. find a rectangular region that contains all masked areas in an image
|
||||
h, w = mask.shape
|
||||
crop_left = 0
|
||||
for i in range(w):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_left += 1
|
||||
|
||||
crop_right = 0
|
||||
for i in reversed(range(w)):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_right += 1
|
||||
|
||||
crop_top = 0
|
||||
for i in range(h):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_top += 1
|
||||
|
||||
crop_bottom = 0
|
||||
for i in reversed(range(h)):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_bottom += 1
|
||||
|
||||
# 2. add padding to the crop region
|
||||
x1, y1, x2, y2 = (
|
||||
int(max(crop_left - pad, 0)),
|
||||
int(max(crop_top - pad, 0)),
|
||||
int(min(w - crop_right + pad, w)),
|
||||
int(min(h - crop_bottom + pad, h)),
|
||||
)
|
||||
|
||||
# 3. expands crop region to match the aspect ratio of the image to be processed
|
||||
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
||||
ratio_processing = width / height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2 - y1))
|
||||
y1 -= desired_height_diff // 2
|
||||
y2 += desired_height_diff - desired_height_diff // 2
|
||||
if y2 >= mask_image.height:
|
||||
diff = y2 - mask_image.height
|
||||
y2 -= diff
|
||||
y1 -= diff
|
||||
if y1 < 0:
|
||||
y2 -= y1
|
||||
y1 -= y1
|
||||
if y2 >= mask_image.height:
|
||||
y2 = mask_image.height
|
||||
else:
|
||||
desired_width = (y2 - y1) * ratio_processing
|
||||
desired_width_diff = int(desired_width - (x2 - x1))
|
||||
x1 -= desired_width_diff // 2
|
||||
x2 += desired_width_diff - desired_width_diff // 2
|
||||
if x2 >= mask_image.width:
|
||||
diff = x2 - mask_image.width
|
||||
x2 -= diff
|
||||
x1 -= diff
|
||||
if x1 < 0:
|
||||
x2 -= x1
|
||||
x1 -= x1
|
||||
if x2 >= mask_image.width:
|
||||
x2 = mask_image.width
|
||||
|
||||
return x1, y1, x2, y2
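
The two new static helpers above (`blur` and `get_crop_region`) are typically combined when preparing an inpainting mask. A minimal sketch; only the method signatures come from this diff, the file names are assumptions.

```py
# Minimal sketch of the new mask helpers; file names are assumptions.
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
mask = Image.open("mask.png").convert("L")            # white where the user painted
init_image = Image.open("photo.png").convert("RGB")

# Soften the mask edges, then find a crop around the masked area that
# matches the 512x512 aspect ratio used for processing.
blurred_mask = processor.blur(mask, blur_factor=4)
x1, y1, x2, y2 = processor.get_crop_region(blurred_mask, width=512, height=512, pad=32)
cropped_init = init_image.crop((x1, y1, x2, y2))
```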
def _resize_and_fill(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
|
||||
ratio = width / height
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = width if ratio < src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
if fill_height > 0:
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(
|
||||
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
||||
box=(0, fill_height + src_h),
|
||||
)
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
if fill_width > 0:
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(
|
||||
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
||||
box=(fill_width + src_w, 0),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _resize_and_crop(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
ratio = width / height
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = width if ratio > src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
return res
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor.
|
||||
height (`int`):
|
||||
The height to resize to.
|
||||
width (`int`):
|
||||
The width to resize to.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode to use; can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintain the original aspect ratio.
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, filling the empty space with data from the image.
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, cropping the excess.
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
||||
The resized image.
|
||||
"""
|
||||
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
if resize_mode == "default":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
elif resize_mode == "fill":
|
||||
image = self._resize_and_fill(image, width, height)
|
||||
elif resize_mode == "crop":
|
||||
image = self._resize_and_crop(image, width, height)
|
||||
else:
|
||||
raise ValueError(f"resize_mode {resize_mode} is not supported")
|
||||
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = self.numpy_to_pt(image)
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
return image
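
A quick comparison of the three `resize_mode` options added above, as a minimal sketch (the input file name is an assumption):

```py
# Minimal sketch of the new resize_mode options; the file name is an assumption.
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
image = Image.open("photo.png").convert("RGB")

stretched = processor.resize(image, height=512, width=512)                    # may change the aspect ratio
filled = processor.resize(image, height=512, width=512, resize_mode="fill")   # keep ratio, pad with image data
cropped = processor.resize(image, height=512, width=512, resize_mode="crop")  # keep ratio, crop the excess
```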
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Create a mask.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image input, should be a PIL image.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
return image
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
@@ -209,67 +447,34 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return height, width
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height to resize to.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width to resize to.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
||||
The resized image.
|
||||
"""
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = self.numpy_to_pt(image)
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
return image
|
||||
|
||||
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Create a mask.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image input, should be a PIL image.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
||||
image: PipelineImageInput,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
||||
Preprocess the image input.
|
||||
|
||||
Args:
|
||||
image (`pipeline_image_input`):
|
||||
The image input; accepted formats are PIL images, NumPy arrays, and PyTorch tensors, as well as lists of these formats.
height (`int`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default height.
width (`int`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode; can be one of `default`, `fill`, or `crop`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintain the original aspect ratio.
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, filling the empty space with data from the image.
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, cropping the excess.
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
@@ -299,13 +504,15 @@ class VaeImageProcessor(ConfigMixin):
|
||||
)
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
if crops_coords is not None:
|
||||
image = [i.crop(crops_coords) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_default_height_width(image[0], height, width)
|
||||
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
||||
if self.config.do_convert_rgb:
|
||||
image = [self.convert_to_rgb(i) for i in image]
|
||||
elif self.config.do_convert_grayscale:
|
||||
image = [self.convert_to_grayscale(i) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_default_height_width(image[0], height, width)
|
||||
image = [self.resize(i, height, width) for i in image]
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
|
||||
@@ -406,6 +613,39 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if output_type == "pil":
|
||||
return self.numpy_to_pil(image)
|
||||
|
||||
def apply_overlay(
|
||||
self,
|
||||
mask: PIL.Image.Image,
|
||||
init_image: PIL.Image.Image,
|
||||
image: PIL.Image.Image,
|
||||
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Overlay the inpainted output on the original image.
|
||||
"""
|
||||
|
||||
width, height = image.width, image.height
|
||||
|
||||
init_image = self.resize(init_image, width=width, height=height)
|
||||
mask = self.resize(mask, width=width, height=height)
|
||||
|
||||
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
||||
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
||||
init_image_masked = init_image_masked.convert("RGBA")
|
||||
|
||||
if crop_coords is not None:
|
||||
x, y, w, h = crop_coords
|
||||
base_image = PIL.Image.new("RGBA", (width, height))
|
||||
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image.convert("RGB")
|
||||
|
||||
image = image.convert("RGBA")
|
||||
image.alpha_composite(init_image_masked)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
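
`apply_overlay` pastes an inpainting result back onto the original photo so that only the masked region changes. A minimal sketch (not part of this diff); the file names and the idea of feeding it a pipeline output are assumptions, while the signature comes from the code above.

```py
# Minimal sketch: merge an inpainting result with the original image.
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
mask = Image.open("mask.png")
init_image = Image.open("photo.png")
result = Image.open("inpainted.png")  # e.g. the output of an inpainting pipeline

# Keep the unmasked pixels from the original image and take the inpainted
# content only inside the mask (optionally restricted to a crop region).
merged = processor.apply_overlay(mask, init_image, result, crop_coords=None)
merged.save("merged.png")
```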
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
"""
|
||||
|
||||
@@ -149,9 +149,11 @@ class IPAdapterMixin:
|
||||
self.feature_extractor = CLIPImageProcessor()
|
||||
|
||||
# load ip-adapter into unet
|
||||
self.unet._load_ip_adapter_weights(state_dict)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dict)
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
for attn_processor in self.unet.attn_processors.values():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = scale
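
The change above makes the IP-Adapter helpers resolve the UNet through `unet_name`, so they also work for pipelines that do not expose `self.unet`. A minimal usage sketch; the IP-Adapter repository and weight file names are common defaults and are assumptions here, not taken from this comparison.

```py
# Minimal sketch of loading an IP-Adapter and adjusting its scale.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale(0.6)  # resolves the UNet via `unet_name` under the hood
```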
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -912,10 +913,10 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
|
||||
@@ -975,6 +976,8 @@ class LoraLoaderMixin:
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
logger.warn(
|
||||
@@ -982,13 +985,13 @@ class LoraLoaderMixin:
|
||||
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
||||
)
|
||||
|
||||
for _, module in self.unet.named_modules():
|
||||
for _, module in unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
else:
|
||||
recurse_remove_peft_layers(self.unet)
|
||||
if hasattr(self.unet, "peft_config"):
|
||||
del self.unet.peft_config
|
||||
recurse_remove_peft_layers(unet)
|
||||
if hasattr(unet, "peft_config"):
|
||||
del unet.peft_config
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
@@ -999,6 +1002,7 @@ class LoraLoaderMixin:
|
||||
fuse_text_encoder: bool = True,
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
@@ -1018,6 +1022,21 @@ class LoraLoaderMixin:
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
@@ -1027,24 +1046,44 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
# TODO(Patrick, Younes): enable "safe" fusing
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge()
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
||||
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
||||
"backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
@@ -1059,9 +1098,9 @@ class LoraLoaderMixin:
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -1080,13 +1119,14 @@ class LoraLoaderMixin:
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if unfuse_unet:
|
||||
if not USE_PEFT_BACKEND:
|
||||
self.unet.unfuse_lora()
|
||||
unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in self.unet.modules():
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
@@ -1202,8 +1242,9 @@ class LoraLoaderMixin:
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[List[float]] = None,
|
||||
):
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
# Handle the UNET
|
||||
self.unet.set_adapters(adapter_names, adapter_weights)
|
||||
unet.set_adapters(adapter_names, adapter_weights)
|
||||
|
||||
# Handle the Text Encoder
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1216,7 +1257,8 @@ class LoraLoaderMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Disable unet adapters
|
||||
self.unet.disable_lora()
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.disable_lora()
|
||||
|
||||
# Disable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1229,7 +1271,8 @@ class LoraLoaderMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Enable unet adapters
|
||||
self.unet.enable_lora()
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.enable_lora()
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1251,7 +1294,8 @@ class LoraLoaderMixin:
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
# Delete unet adapters
|
||||
self.unet.delete_adapters(adapter_names)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.delete_adapters(adapter_names)
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
# Delete text encoder adapters
|
||||
@@ -1284,8 +1328,8 @@ class LoraLoaderMixin:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for module in self.unet.modules():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
@@ -1309,8 +1353,9 @@ class LoraLoaderMixin:
|
||||
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
|
||||
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
|
||||
|
||||
if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
|
||||
set_adapters["unet"] = list(self.unet.peft_config.keys())
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
|
||||
set_adapters[self.unet_name] = list(unet.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
@@ -1331,7 +1376,8 @@ class LoraLoaderMixin:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
# Handle the UNET
|
||||
for unet_module in self.unet.modules():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for unet_module in unet.modules():
|
||||
if isinstance(unet_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
unet_module.lora_A[adapter_name].to(device)
|
||||
|
||||
@@ -11,9 +11,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
@@ -504,22 +506,43 @@ class UNet2DConditionLoadersMixin:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
self.lora_scale = lora_scale
|
||||
self._safe_fusing = safe_fusing
|
||||
self.apply(self._fuse_lora_apply)
|
||||
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
|
||||
|
||||
def _fuse_lora_apply(self, module):
|
||||
def _fuse_lora_apply(self, module, adapter_names=None):
|
||||
if not USE_PEFT_BACKEND:
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora(self.lora_scale, self._safe_fusing)
|
||||
|
||||
if adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch"
|
||||
" to PEFT backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
merge_kwargs = {"safe_merge": self._safe_fusing}
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
module.merge(safe_merge=self._safe_fusing)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
|
||||
" to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
def unfuse_lora(self):
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
|
||||
@@ -32,7 +32,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
@@ -67,7 +66,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnetxs import ControlNetXSModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
@@ -128,12 +128,6 @@ else:
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["controlnet_xs"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
@@ -361,10 +355,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .controlnet_xs import (
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
|
||||
@@ -31,7 +31,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import AmusedPipeline
|
||||
|
||||
>>> pipe = AmusedPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... "amused/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... "amused/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... "amused/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -33,7 +33,14 @@ from ...schedulers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
@@ -47,7 +54,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
|
||||
>>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter")
|
||||
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
>>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
|
||||
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
|
||||
>>> output = pipe(prompt="A corgi walking in the park")
|
||||
@@ -533,6 +540,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
|
||||
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -768,6 +768,10 @@ class StableDiffusionPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -909,6 +913,7 @@ class StableDiffusionPipeline(
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -986,6 +991,9 @@ class StableDiffusionPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
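The new `interrupt` property and the `if self.interrupt: continue` guard added above let a `callback_on_step_end` stop the denoising loop early. A minimal sketch of how that could be driven from user code follows; the checkpoint and step index are illustrative, and the flag itself is still the private `_interrupt` attribute:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

def stop_early(pipe, step_index, timestep, callback_kwargs):
    # Flip the flag checked by `self.interrupt`; the remaining steps are skipped.
    if step_index == 10:
        pipe._interrupt = True
    return callback_kwargs

# The returned image corresponds to the latents at the interrupted step.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=50,
    callback_on_step_end=stop_early,
).images[0]
```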
@@ -832,6 +832,10 @@ class StableDiffusionImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -963,6 +967,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1041,6 +1046,9 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -636,6 +636,8 @@ class StableDiffusionInpaintPipeline(
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
@@ -644,6 +646,7 @@ class StableDiffusionInpaintPipeline(
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
@@ -689,6 +692,21 @@ class StableDiffusionInpaintPipeline(
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if padding_mask_crop is not None:
|
||||
if self.unet.config.in_channels != 4:
|
||||
raise ValueError(
|
||||
f"The UNet should have 4 input channels for inpainting mask crop, but has"
|
||||
f" {self.unet.config.in_channels} input channels."
|
||||
)
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -958,6 +976,10 @@ class StableDiffusionInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -967,6 +989,7 @@ class StableDiffusionInpaintPipeline(
|
||||
masked_image_latents: torch.FloatTensor = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
@@ -1011,6 +1034,12 @@ class StableDiffusionInpaintPipeline(
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of the margin in the crop applied to the image and mask. If `None`, no crop is applied to the image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
contains all of the masked area, and then expand that region based on `padding_mask_crop`. The image and mask_image are then cropped based on
the expanded area before being resized to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contains information irrelevant to inpainting, such as background.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
@@ -1131,6 +1160,8 @@ class StableDiffusionInpaintPipeline(
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
@@ -1139,11 +1170,13 @@ class StableDiffusionInpaintPipeline(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1202,7 +1235,17 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
if padding_mask_crop is not None:
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@@ -1232,7 +1275,9 @@ class StableDiffusionInpaintPipeline(
|
||||
latents, noise = latents_outputs
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
mask_condition = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask_condition < 0.5)
|
||||
@@ -1288,6 +1333,9 @@ class StableDiffusionInpaintPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
@@ -1372,6 +1420,9 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
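Taken together, the inpainting hunks above crop the image and mask to the masked region plus a margin, inpaint that crop, and paste the result back with `apply_overlay`. A hedged usage sketch of the new `padding_mask_crop` argument follows; the checkpoint is assumed, the image URLs are placeholders, and both inputs must be PIL images per the new `check_inputs` validation:

```py
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/bench.png")     # placeholder URL
mask = load_image("https://example.com/bench_mask.png")  # placeholder URL

# Crop to the masked region plus a 32-pixel margin, inpaint at full resolution,
# then overlay the result back onto the original image.
result = pipe(
    "a red park bench",
    image=image,
    mask_image=mask,
    padding_mask_crop=32,
).images[0]
result.save("inpainted.png")
```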
@@ -849,6 +849,10 @@ class StableDiffusionXLPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1067,6 +1071,7 @@ class StableDiffusionXLPipeline(
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1196,6 +1201,9 @@ class StableDiffusionXLPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -990,6 +990,10 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1221,6 +1225,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1376,6 +1381,9 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -1210,6 +1210,10 @@ class StableDiffusionXLInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1462,6 +1466,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1684,6 +1689,8 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -293,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
variance_noise: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -332,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 1. get previous step value (=t+1)
|
||||
prev_timestep = timestep
|
||||
timestep = min(
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
|
||||
@@ -89,6 +89,43 @@ def betas_for_alpha_bar(
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
`DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
|
||||
@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -153,6 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: int = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -171,6 +213,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.one = torch.tensor(1.0)
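A small sketch of the `rescale_betas_zero_snr` flag added above in use; with the rescaling enabled, the final cumulative alpha is driven to zero, so the last timestep corresponds to pure noise. The scheduler settings here are illustrative:

```py
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(beta_schedule="linear", rescale_betas_zero_snr=True)

# The terminal SNR is (numerically) zero after the rescaling above.
print(scheduler.alphas_cumprod[-1])  # ~tensor(0.)
```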
@@ -91,6 +91,43 @@ def betas_for_alpha_bar(
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
|
||||
@@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -163,6 +204,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: int = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -181,6 +223,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
self.one = torch.tensor(1.0)
|
||||
|
||||
@@ -92,21 +92,6 @@ class ControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ControlNetXSModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Kandinsky3UNet(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -782,21 +782,6 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1142,21 +1127,6 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -79,6 +79,14 @@ PEFT_TO_DIFFUSERS = {
|
||||
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
|
||||
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
|
||||
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
|
||||
"to_k.lora_A": "to_k.lora.down",
|
||||
"to_k.lora_B": "to_k.lora.up",
|
||||
"to_q.lora_A": "to_q.lora.down",
|
||||
"to_q.lora_B": "to_q.lora.up",
|
||||
"to_v.lora_A": "to_v.lora.down",
|
||||
"to_v.lora_B": "to_v.lora.up",
|
||||
"to_out.0.lora_A": "to_out.0.lora.down",
|
||||
"to_out.0.lora_B": "to_out.0.lora.up",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_DIFFUSERS = {
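The new entries extend the attention-projection coverage of the PEFT-to-diffusers key mapping. For illustration, here is a minimal sketch of how such a substring-renaming table is typically applied to a LoRA state dict; the helper below is hypothetical and is not the converter used in the library:

```py
import torch

# Illustrative subset of the mapping added above.
PEFT_TO_DIFFUSERS = {
    "to_k.lora_A": "to_k.lora.down",
    "to_k.lora_B": "to_k.lora.up",
    "to_q.lora_A": "to_q.lora.down",
    "to_q.lora_B": "to_q.lora.up",
}

def rename_lora_keys(state_dict, mapping):
    # Replace every PEFT-style substring with its diffusers-style counterpart.
    renamed = {}
    for key, value in state_dict.items():
        for peft_key, diffusers_key in mapping.items():
            if peft_key in key:
                key = key.replace(peft_key, diffusers_key)
        renamed[key] = value
    return renamed

example = {"unet.mid_block.attn.to_q.lora_A.weight": torch.zeros(4, 8)}
print(rename_lora_keys(example, PEFT_TO_DIFFUSERS))
# {'unet.mid_block.attn.to_q.lora.down.weight': tensor(...)}
```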
@@ -300,6 +300,23 @@ def require_peft_backend(test_case):
|
||||
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
|
||||
versions of PEFT and transformers.
|
||||
"""
|
||||
|
||||
def decorator(test_case):
|
||||
correct_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse(peft_version)
|
||||
return unittest.skipUnless(
|
||||
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
|
||||
)(test_case)
|
||||
|
||||
return decorator
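A short sketch of how the new decorator is meant to be used in a test module; the test body is a placeholder:

```py
import unittest

from diffusers.utils.testing_utils import require_peft_version_greater


class ExampleLoraTests(unittest.TestCase):
    @require_peft_version_greater("0.6.2")
    def test_needs_recent_peft(self):
        # Only runs when the installed PEFT version is newer than 0.6.2.
        self.assertTrue(True)
```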
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -50,6 +50,7 @@ from diffusers.utils.testing_utils import (
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -1105,6 +1106,68 @@ class PeftLoraLoaderMixinTests:
|
||||
{"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
|
||||
)
|
||||
|
||||
@require_peft_version_greater(peft_version="0.6.2")
|
||||
def test_simple_inference_with_text_lora_unet_fused_multi(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
|
||||
and makes sure it works as expected - with unet and multi-adapter case
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
|
||||
|
||||
# Attach a second adapter
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
|
||||
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
# set them to multi-adapter inference mode
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"])
|
||||
ouputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
pipe.set_adapters(["adapter-1"])
|
||||
ouputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
pipe.fuse_lora(adapter_names=["adapter-1"])
|
||||
|
||||
# Fusing should still keep the LoRA layers so the output should remain the same
|
||||
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
|
||||
"Fused lora should not change the output",
|
||||
)
|
||||
|
||||
pipe.unfuse_lora()
|
||||
pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
|
||||
|
||||
# Fusing should still keep the LoRA layers
|
||||
output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(
|
||||
np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
|
||||
"Fused lora should not change the output",
|
||||
)
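Outside of the test harness, the flow being verified looks roughly like the following; the base model and LoRA repositories are placeholders, and `fuse_lora(adapter_names=...)` is the new capability exercised above:

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Placeholder LoRA repositories.
pipe.load_lora_weights("your-org/lora-one", adapter_name="adapter-1")
pipe.load_lora_weights("your-org/lora-two", adapter_name="adapter-2")

# Activate the first adapter, then fuse only that one into the base weights;
# the fused output should match the unfused "adapter-1 only" output.
pipe.set_adapters(["adapter-1"])
pipe.fuse_lora(adapter_names=["adapter-1"])
image = pipe("toy robot", num_inference_steps=20).images[0]

# Unfuse before fusing a different adapter combination.
pipe.unfuse_lora()
pipe.fuse_lora(adapter_names=["adapter-1", "adapter-2"])
```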
@unittest.skip("This is failing for now - need to investigate")
|
||||
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
|
||||
"""
|
||||
|
||||
@@ -133,7 +133,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class AmusedPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe = AmusedPipeline.from_pretrained("amused/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
@@ -145,7 +145,7 @@ class AmusedPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe = AmusedPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
@@ -157,7 +157,7 @@ class AmusedPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 7e-3
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe = AmusedPipeline.from_pretrained("amused/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
@@ -169,7 +169,7 @@ class AmusedPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 3e-3
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedPipeline.from_pretrained("huggingface/amused-512", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
|
||||
|
||||
@@ -137,7 +137,7 @@ class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -162,9 +162,7 @@ class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
"huggingface/amused-256", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256", torch_dtype=torch.float16, variant="fp16")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -189,7 +187,7 @@ class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-2
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -213,9 +211,7 @@ class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained(
|
||||
"huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
|
||||
@@ -141,7 +141,7 @@ class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
class AmusedInpaintPipelineSlowTests(unittest.TestCase):
|
||||
def test_amused_256(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-256")
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -174,9 +174,7 @@ class AmusedInpaintPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_256_fp16(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
"huggingface/amused-256", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -209,7 +207,7 @@ class AmusedInpaintPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.1
|
||||
|
||||
def test_amused_512(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-512")
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512")
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
@@ -242,9 +240,7 @@ class AmusedInpaintPipelineSlowTests(unittest.TestCase):
|
||||
assert np.abs(image_slice - expected_slice).max() < 0.05
|
||||
|
||||
def test_amused_512_fp16(self):
|
||||
pipe = AmusedInpaintPipeline.from_pretrained(
|
||||
"huggingface/amused-512", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
|
||||
image = (
|
||||
|
||||
@@ -1,311 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetXSModel,
|
||||
DDIMScheduler,
|
||||
LCMScheduler,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
_ = in_queue.get(timeout=timeout)
|
||||
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.controlnet.to(memory_format=torch.channels_last)
|
||||
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
|
||||
)
|
||||
expected_image = np.resize(expected_image, (512, 512, 3))
|
||||
|
||||
assert np.abs(expected_image - image).max() < 1.0
|
||||
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class ControlNetXSPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetXSPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
norm_num_groups=1,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetXSModel.from_unet(
|
||||
unet=unet,
|
||||
time_embedding_mix=0.95,
|
||||
learn_embedding=True,
|
||||
size_ratio=0.5,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
num_attention_heads=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_controlnet_lcm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components(time_cond_proj_dim=256)
|
||||
sd_pipe = StableDiffusionControlNetXSPipeline(**components)
|
||||
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = sd_pipe(**inputs)
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetXSPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 512, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.1274, 0.1401, 0.147, 0.1185, 0.1555, 0.1492, 0.1565, 0.1474, 0.1701])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_depth(self):
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-depth")
|
||||
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.1098, 0.1025, 0.1211, 0.1129, 0.1165, 0.1262, 0.1185, 0.1261, 0.1703])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(original_image, expected_image)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@require_python39_or_higher
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
|
||||
@@ -1,362 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetXSModel,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipelineFastTests(
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetXSPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
cross_attention_dim=64,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetXSModel.from_unet(
|
||||
unet,
|
||||
time_embedding_mix=0.95,
|
||||
learn_embedding=True,
|
||||
size_ratio=0.5,
|
||||
conditioning_embedding_out_channels=(16, 32),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
}
|
||||
return components
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
pipe.unet.set_default_attn_processor()
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_stable_diffusion_xl_multi_prompts(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
|
||||
# forward with single prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_1 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# forward with same prompt duplicated
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = inputs["prompt"]
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_2 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# ensure the results are equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
|
||||
|
||||
# forward with different prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "different prompt"
|
||||
output = sd_pipe(**inputs)
|
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

        # manually set a negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with same negative_prompt duplicated
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = inputs["negative_prompt"]
        output = sd_pipe(**inputs)
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # ensure the results are equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

        # forward with different negative_prompt
        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt"] = "negative prompt"
        inputs["negative_prompt_2"] = "different negative prompt"
        output = sd_pipe(**inputs)
        image_slice_3 = output.images[0, -3:, -3:, -1]

        # ensure the results are not equal
        assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

    # copied from test_stable_diffusion_xl.py
    def test_stable_diffusion_xl_prompt_embeds(self):
        components = self.get_dummy_components()
        sd_pipe = self.pipeline_class(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        # forward without prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 2 * [inputs["prompt"]]
        inputs["num_images_per_prompt"] = 2

        output = sd_pipe(**inputs)
        image_slice_1 = output.images[0, -3:, -3:, -1]

        # forward with prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        prompt = 2 * [inputs.pop("prompt")]

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = sd_pipe.encode_prompt(prompt)

        output = sd_pipe(
            **inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )
        image_slice_2 = output.images[0, -3:, -3:, -1]

        # make sure that it's equal
        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4


@slow
@require_torch_gpu
class ControlNetSDXLPipelineXSSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_canny(self):
        controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny")

        pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
        )
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "bird"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (768, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array([0.4359, 0.4335, 0.4609, 0.4515, 0.4669, 0.4494, 0.452, 0.4493, 0.4382])
        assert np.allclose(original_image, expected_image, atol=1e-04)

    def test_depth(self):
        controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-depth")

        pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
        )
        pipe.enable_sequential_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        prompt = "Stormtrooper's lecture"
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
        )

        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images

        assert images[0].shape == (512, 512, 3)

        original_image = images[0, -3:, -3:, -1].flatten()
        expected_image = np.array([0.4411, 0.3617, 0.2654, 0.266, 0.3449, 0.3898, 0.3745, 0.353, 0.326])
        assert np.allclose(original_image, expected_image, atol=1e-04)
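The slow tests above exercise ControlNet-XS with SDXL end to end against real checkpoints. For orientation, here is a minimal standalone sketch of the same canny-conditioned generation outside the test harness. It assumes `ControlNetXSModel` and `StableDiffusionXLControlNetXSPipeline` are importable from the top-level `diffusers` namespace and reuses the checkpoints and conditioning image referenced in `test_canny`; the step count and output filename are illustrative, not taken from this diff.

```py
# Minimal sketch (not the canonical example): canny-conditioned SDXL generation with
# ControlNet-XS, mirroring ControlNetSDXLPipelineXSSlowTests.test_canny above.
import torch
from diffusers import ControlNetXSModel, StableDiffusionXLControlNetXSPipeline
from diffusers.utils import load_image

controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny")
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
pipe.enable_sequential_cpu_offload()  # keeps peak VRAM low, as in the slow test

# Canny edge map used as the conditioning image in the test.
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("bird", image=canny_image, generator=generator, num_inference_steps=30).images[0]
image.save("bird_controlnet_xs.png")  # illustrative output path
```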
@@ -692,6 +692,58 @@ class StableDiffusionPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
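The hunk above adds `test_pipeline_interrupt`, which exercises the new interruption hook: a `callback_on_step_end` callback may set `pipe._interrupt = True`, and the denoising loop then skips the remaining steps, returning whatever has been computed so far. A rough user-facing sketch of the same pattern follows; the checkpoint name is only an example, and `_interrupt` is the private flag the test itself pokes.

```py
# Sketch of stopping generation early through callback_on_step_end, mirroring the test above.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("cuda")


def stop_early(pipe, step_index, timestep, callback_kwargs):
    # Setting the private `_interrupt` flag makes the pipeline skip the remaining denoising steps.
    if step_index == 2:
        pipe._interrupt = True
    return callback_kwargs


# output_type="latent" returns the partially denoised latents instead of decoded images.
latents = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=50,
    output_type="latent",
    generator=torch.Generator("cpu").manual_seed(0),
    callback_on_step_end=stop_early,
).images
```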
@@ -320,6 +320,62 @@ class StableDiffusionImg2ImgPipelineFastTests(
    def test_float16_inference(self):
        super().test_float16_inference(expected_max_diff=5e-1)

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
@require_torch_gpu
@@ -319,6 +319,64 @@ class StableDiffusionInpaintPipelineFastTests(
        out_1 = sd_pipe(**inputs).images
        assert np.abs(out_0 - out_1).max() < 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
    pipeline_class = StableDiffusionInpaintPipeline
@@ -969,6 +969,58 @@ class StableDiffusionXLPipelineFastTests(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "hey"
        num_inference_steps = 3

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
@@ -439,6 +439,64 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
            > 1e-4
        )

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)


class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
    PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
@@ -746,3 +746,63 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
        image_slice1 = images[0, -3:, -3:, -1]
        image_slice2 = images[1, -3:, -3:, -1]
        assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2

    def test_pipeline_interrupt(self):
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLInpaintPipeline(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)

        prompt = "hey"
        num_inference_steps = 5

        # store intermediate latents from the generation process
        class PipelineState:
            def __init__(self):
                self.state = []

            def apply(self, pipe, i, t, callback_kwargs):
                self.state.append(callback_kwargs["latents"])
                return callback_kwargs

        pipe_state = PipelineState()
        sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="np",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=pipe_state.apply,
        ).images

        # interrupt generation at step index
        interrupt_step_idx = 1

        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if i == interrupt_step_idx:
                pipe._interrupt = True

            return callback_kwargs

        output_interrupted = sd_pipe(
            prompt,
            image=inputs["image"],
            mask_image=inputs["mask_image"],
            strength=0.8,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            generator=torch.Generator("cpu").manual_seed(0),
            callback_on_step_end=callback_on_step_end,
        ).images

        # fetch intermediate latents at the interrupted step
        # from the completed generation process
        intermediate_latent = pipe_state.state[interrupt_step_idx]

        # compare the intermediate latent to the output of the interrupted process
        # they should be the same
        assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
@@ -7,7 +7,7 @@ from .test_schedulers import SchedulerCommonTest


class DDIMInverseSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDIMInverseScheduler,)
    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
@@ -26,7 +26,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0
        num_inference_steps = 10

        model = self.dummy_model()
        sample = self.dummy_sample_deter
@@ -35,7 +35,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, eta).prev_sample
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample
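The three hunks above remove the `eta` argument from the `DDIMInverseScheduler` tests: in each pair of adjacent lines the first is the old version and the second the new one, and `step()` is now called with only the model output, the timestep, and the sample. A small sketch of the updated stepping pattern, with a toy callable standing in for a real noise-prediction model:

```py
# Sketch of stepping DDIMInverseScheduler without `eta`, as in the updated test.
import torch
from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 8, 8)  # stand-in for an initial latent


def toy_model(x, t):
    # stand-in for a UNet noise prediction
    return 0.1 * x


for t in scheduler.timesteps:
    residual = toy_model(sample, t)
    sample = scheduler.step(residual, t, sample).prev_sample  # no eta argument anymore
```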
@@ -68,6 +68,10 @@ class DDPMSchedulerTest(SchedulerCommonTest):
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def test_rescale_betas_zero_snr(self):
        for rescale_betas_zero_snr in [True, False]:
            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
@@ -82,6 +82,10 @@ class DDPMParallelSchedulerTest(SchedulerCommonTest):
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def test_rescale_betas_zero_snr(self):
        for rescale_betas_zero_snr in [True, False]:
            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

    def test_batch_step_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
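Both hunks above add a `test_rescale_betas_zero_snr` case, confirming that `DDPMScheduler` and `DDPMParallelScheduler` accept the `rescale_betas_zero_snr` config flag, which rescales the beta schedule so the terminal step has zero signal-to-noise ratio. A hedged sketch of turning the flag on, either directly or by overriding the scheduler config of an existing pipeline; the checkpoint name is only an example.

```py
# Sketch: enabling rescale_betas_zero_snr on a DDPMScheduler.
from diffusers import DDPMScheduler, StableDiffusionPipeline

# direct construction
scheduler = DDPMScheduler(num_train_timesteps=1000, rescale_betas_zero_snr=True)

# or override the flag while keeping the rest of an existing pipeline's scheduler config
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
```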