mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-22 04:14:43 +08:00
Compare commits
1 Commits
test-clean
...
v0.20.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3380c02cf |
@@ -190,8 +190,6 @@
|
||||
title: Audio Diffusion
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/consistency_models
|
||||
|
||||
@@ -46,5 +46,6 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
## StableDiffusionPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -1,93 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AudioLDM 2
|
||||
|
||||
AudioLDM 2 was proposed in [AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining](https://arxiv.org/abs/2308.05734)
|
||||
by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate
|
||||
text-conditional sound effects, human speech and music.
|
||||
|
||||
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM 2
|
||||
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two
|
||||
text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap)
|
||||
and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings
|
||||
are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/main/api/pipelines/audioldm2#diffusers.AudioLDM2ProjectionModel).
|
||||
A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively
|
||||
predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding
|
||||
vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2UNet2DConditionModel)
|
||||
of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention
|
||||
conditioning, as in most other LDMs.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called language of audio (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate new state-of-the-art or competitive performance to previous approaches.*
|
||||
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be
|
||||
found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
|
||||
|
||||
## Tips
|
||||
|
||||
### Choosing a checkpoint
|
||||
|
||||
AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio
|
||||
generation. The third checkpoint is trained exclusively on text-to-music generation.
|
||||
|
||||
All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet.
|
||||
See table below for details on the three checkpoints:
|
||||
|
||||
| Checkpoint | Task | UNet Model Size | Total Model Size | Training Data / h |
|
||||
|-----------------------------------------------------------------|---------------|-----------------|------------------|-------------------|
|
||||
| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
|
||||
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
|
||||
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
|
||||
|
||||
### Constructing a prompt
|
||||
|
||||
* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g. "water stream in a forest" instead of "stream").
|
||||
* It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
|
||||
* Using a **negative prompt** can significantly improve the quality of the generated waveform, by guiding the generation away from terms that correspond to poor quality audio. Try using a negative prompt of "Low quality."
|
||||
|
||||
### Controlling inference
|
||||
|
||||
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
|
||||
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
|
||||
|
||||
### Evaluating generated waveforms:
|
||||
|
||||
* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
|
||||
The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between
|
||||
scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines)
|
||||
section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AudioLDM2Pipeline
|
||||
[[autodoc]] AudioLDM2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AudioLDM2ProjectionModel
|
||||
[[autodoc]] AudioLDM2ProjectionModel
|
||||
- forward
|
||||
|
||||
## AudioLDM2UNet2DConditionModel
|
||||
[[autodoc]] AudioLDM2UNet2DConditionModel
|
||||
- forward
|
||||
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
@@ -20,12 +20,6 @@ The abstract from the [paper](https://arxiv.org/abs/2303.06555) is:
|
||||
|
||||
You can find the original codebase at [thu-ml/unidiffuser](https://github.com/thu-ml/unidiffuser) and additional checkpoints at [thu-ml](https://huggingface.co/thu-ml).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [dg845](https://github.com/dg845). ❤️
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -265,7 +265,7 @@ distributed_type: DEEPSPEED
|
||||
|
||||
See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
|
||||
|
||||
</Tip>
|
||||
<Tip>
|
||||
|
||||
Changing the default Adam optimizer to DeepSpeed's Adam
|
||||
`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but
|
||||
@@ -330,4 +330,4 @@ image.save("./output.png")
|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_controlnet_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md).
|
||||
Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_controlnet_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md).
|
||||
@@ -30,7 +30,6 @@ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
@@ -15,7 +15,8 @@ specific language governing permissions and limitations under the License.
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242)는 한 주제에 대한 적은 이미지(3~5개)만으로도 stable diffusion과 같이 text-to-image 모델을 개인화할 수 있는 방법입니다. 이를 통해 모델은 다양한 장면, 포즈 및 장면(뷰)에서 피사체에 대해 맥락화(contextualized)된 이미지를 생성할 수 있습니다.
|
||||
|
||||

|
||||
<small>에서의 Dreambooth 예시 <a href="https://dreambooth.github.io">project's blog.</a></small>
|
||||
<a href="https://dreambooth.github.io">project's blog.</a></small>
|
||||
<small><a href="https://dreambooth.github.io">프로젝트 블로그</a>에서의 Dreambooth 예시</small>
|
||||
|
||||
|
||||
이 가이드는 다양한 GPU, Flax 사양에 대해 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 DreamBooth를 파인튜닝하는 방법을 보여줍니다. 더 깊이 파고들어 작동 방식을 확인하는 데 관심이 있는 경우, 이 가이드에 사용된 DreamBooth의 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)에서 찾을 수 있습니다.
|
||||
@@ -471,4 +472,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
image.save("dog-bucket.png")
|
||||
```
|
||||
|
||||
[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서도 추론을 실행할 수도 있습니다.
|
||||
[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서도 추론을 실행할 수도 있습니다.
|
||||
@@ -39,8 +39,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
|
||||
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit)
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
@@ -1530,44 +1529,6 @@ CLIP guided stable diffusion images mixing pipline allows to combine two images
|
||||
This approach is using (optional) CoCa model to avoid writing image description.
|
||||
[More code examples](https://github.com/TheDenk/images_mixing)
|
||||
|
||||
|
||||
### Stable Diffusion XL Long Weighted Prompt Pipeline
|
||||
|
||||
This SDXL pipeline support unlimted length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
|
||||
You can provide both `prompt` and `prompt_2`. if only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
, torch_dtype = torch.float16
|
||||
, use_safetensors = True
|
||||
, variant = "fp16"
|
||||
, custom_pipeline = "lpw_stable_diffusion_xl",
|
||||
)
|
||||
|
||||
prompt = "photo of a cute (white) cat running on the grass"*20
|
||||
prompt2 = "chasing (birds:1.5)"*20
|
||||
prompt = f"{prompt},{prompt2}"
|
||||
neg_prompt = "blur, low quality, carton, animate"
|
||||
|
||||
pipe.to("cuda")
|
||||
images = pipe(
|
||||
prompt = prompt
|
||||
, negative_prompt = neg_prompt
|
||||
).images[0]
|
||||
|
||||
pipe.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
images
|
||||
```
|
||||
|
||||
In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
|
||||

|
||||
|
||||
## Example Images Mixing (with CoCa)
|
||||
```python
|
||||
import requests
|
||||
@@ -1889,69 +1850,3 @@ for obj in range(bs):
|
||||
|
||||
```
|
||||
|
||||
### Stable Diffusion XL Reference
|
||||
|
||||
This pipeline uses the Reference . Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference).
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers.utils import load_image
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
# pipe = DiffusionPipeline.from_pretrained(
|
||||
# "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
# custom_pipeline="stable_diffusion_xl_reference",
|
||||
# torch_dtype=torch.float16,
|
||||
# use_safetensors=True,
|
||||
# variant="fp16").to('cuda:0')
|
||||
|
||||
pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16").to('cuda:0')
|
||||
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
```
|
||||
|
||||
Reference Image
|
||||
|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: 1 girl`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
|
||||
Reference Image
|
||||

|
||||
|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: A dog`
|
||||
|
||||
`reference_attn=True, reference_adain=False, num_inference_steps=20`
|
||||

|
||||
|
||||
Reference Image
|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: An astronaut riding a lion`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,799 +0,0 @@
|
||||
# Based on stable_diffusion_reference.py
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UpBlock2D,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
>>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16").to('cuda:0')
|
||||
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
|
||||
>>> result_img.show()
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def torch_dfs(model: torch.nn.Module):
|
||||
result = [model]
|
||||
for child in model.children():
|
||||
result += torch_dfs(child)
|
||||
return result
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
def _default_height_width(self, height, width, image):
|
||||
# NOTE: It is possible that a list of images have different
|
||||
# dimensions for each image, so just checking the first image
|
||||
# is not _exactly_ correct, but it is simple.
|
||||
while isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
|
||||
width = (width // 8) * 8
|
||||
|
||||
return height, width
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
images = []
|
||||
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
|
||||
image = images
|
||||
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = (image - 0.5) / 0.5
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.stack(image, dim=0)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
refimage = refimage.to(dtype=self.vae.dtype)
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
if isinstance(generator, list):
|
||||
ref_image_latents = [
|
||||
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
ref_image_latents = torch.cat(ref_image_latents, dim=0)
|
||||
else:
|
||||
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
|
||||
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
|
||||
|
||||
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
|
||||
if ref_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % ref_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
attention_auto_machine_weight: float = 1.0,
|
||||
gn_auto_machine_weight: float = 1.0,
|
||||
style_fidelity: float = 0.5,
|
||||
reference_attn: bool = True,
|
||||
reference_adain: bool = True,
|
||||
):
|
||||
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
|
||||
|
||||
# 0. Default height and width to unet
|
||||
# height, width = self._default_height_width(height, width, ref_image)
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# 4. Preprocess reference image
|
||||
ref_image = self.prepare_image(
|
||||
image=ref_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
# 7. Prepare reference latent variables
|
||||
ref_image_latents = self.prepare_ref_latents(
|
||||
ref_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Modify self attebtion and group norm
|
||||
MODE = "write"
|
||||
uc_mask = (
|
||||
torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
|
||||
.type_as(ref_image_latents)
|
||||
.bool()
|
||||
)
|
||||
|
||||
def hacked_basic_transformer_inner_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
# 1. Self-Attention
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if self.only_cross_attention:
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
if MODE == "write":
|
||||
self.bank.append(norm_hidden_states.detach().clone())
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if MODE == "read":
|
||||
if attention_auto_machine_weight > self.attn_weight:
|
||||
attn_output_uc = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
|
||||
# attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
attn_output_c = attn_output_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
attn_output_c[uc_mask] = self.attn1(
|
||||
norm_hidden_states[uc_mask],
|
||||
encoder_hidden_states=norm_hidden_states[uc_mask],
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
|
||||
self.bank.clear()
|
||||
else:
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
|
||||
# 2. Cross-Attention
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_mid_forward(self, *args, **kwargs):
|
||||
eps = 1e-6
|
||||
x = self.original_forward(*args, **kwargs)
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append(mean)
|
||||
self.var_bank.append(var)
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
|
||||
var_acc = sum(self.var_bank) / float(len(self.var_bank))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
x_uc = (((x - mean) / std) * std_acc) + mean_acc
|
||||
x_c = x_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
x_c[uc_mask] = x[uc_mask]
|
||||
x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
return x
|
||||
|
||||
def hack_CrossAttnDownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
eps = 1e-6
|
||||
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
||||
eps = 1e-6
|
||||
|
||||
output_states = ()
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_CrossAttnUpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
eps = 1e-6
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
eps = 1e-6
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
if reference_attn:
|
||||
attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
|
||||
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
|
||||
|
||||
for i, module in enumerate(attn_modules):
|
||||
module._original_inner_forward = module.forward
|
||||
module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
|
||||
module.bank = []
|
||||
module.attn_weight = float(i) / float(len(attn_modules))
|
||||
|
||||
if reference_adain:
|
||||
gn_modules = [self.unet.mid_block]
|
||||
self.unet.mid_block.gn_weight = 0
|
||||
|
||||
down_blocks = self.unet.down_blocks
|
||||
for w, module in enumerate(down_blocks):
|
||||
module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
|
||||
gn_modules.append(module)
|
||||
|
||||
up_blocks = self.unet.up_blocks
|
||||
for w, module in enumerate(up_blocks):
|
||||
module.gn_weight = float(w) / float(len(up_blocks))
|
||||
gn_modules.append(module)
|
||||
|
||||
for i, module in enumerate(gn_modules):
|
||||
if getattr(module, "original_forward", None) is None:
|
||||
module.original_forward = module.forward
|
||||
if i == 0:
|
||||
# mid_block
|
||||
module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
|
||||
elif isinstance(module, CrossAttnDownBlock2D):
|
||||
module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
|
||||
elif isinstance(module, DownBlock2D):
|
||||
module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
|
||||
elif isinstance(module, CrossAttnUpBlock2D):
|
||||
module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
|
||||
elif isinstance(module, UpBlock2D):
|
||||
module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
|
||||
module.mean_bank = []
|
||||
module.var_bank = []
|
||||
module.gn_weight *= 2
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 10.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# ref only part
|
||||
noise = randn_tensor(
|
||||
ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
|
||||
)
|
||||
ref_xt = self.scheduler.add_noise(
|
||||
ref_image_latents,
|
||||
noise,
|
||||
t.reshape(
|
||||
1,
|
||||
),
|
||||
)
|
||||
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
|
||||
|
||||
MODE = "write"
|
||||
|
||||
self.unet(
|
||||
ref_xt,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
MODE = "read"
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -673,8 +673,6 @@ likely the learning rate can be increased with larger batch sizes.
|
||||
|
||||
Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
|
||||
|
||||
`--validation_scheduler`: Set a particular scheduler via a string. We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.
|
||||
|
||||
```sh
|
||||
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
|
||||
|
||||
@@ -699,7 +697,6 @@ accelerate launch train_dreambooth.py \
|
||||
--use_8bit_adam \
|
||||
--set_grads_to_none \
|
||||
--skip_save_text_encoder \
|
||||
--validation_scheduler DDPMScheduler \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
@@ -738,7 +735,6 @@ accelerate launch train_dreambooth.py \
|
||||
--text_encoder_use_attention_mask \
|
||||
--validation_images $VALIDATION_IMAGES \
|
||||
--class_labels_conditioning timesteps \
|
||||
--validation_scheduler DDPMScheduler\
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import argparse
|
||||
import copy
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
@@ -48,6 +47,7 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -153,9 +153,7 @@ def log_validation(
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
module = importlib.import_module("diffusers")
|
||||
scheduler_class = getattr(module, args.validation_scheduler)
|
||||
pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -558,13 +556,6 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_scheduler",
|
||||
type=str,
|
||||
default="DPMSolverMultistepScheduler",
|
||||
choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
|
||||
help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -828,87 +828,6 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
|
||||
prompt = "a prompt"
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
prompt = "a prompt"
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--train_text_encoder
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
prompt = "a prompt"
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -396,6 +396,16 @@ def parse_args(input_args=None):
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp32", "fp16", "bf16"],
|
||||
help=(
|
||||
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
@@ -714,15 +724,11 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -996,12 +1002,9 @@ def main(args):
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
else:
|
||||
pixel_values = batch["pixel_values"]
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
@@ -1144,6 +1147,13 @@ def main(args):
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.20.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
||||
# *Only* converts the UNet, VAE, and Text Encoder.
|
||||
# Does not convert optimizer state or any other thing.
|
||||
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import re
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
# the following are for sdxl
|
||||
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
||||
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
||||
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
||||
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
# hardcoded number of downblocks and resnets/attentions...
|
||||
# would need smarter logic for other networks.
|
||||
for i in range(3):
|
||||
# loop over downblocks/upblocks
|
||||
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(4):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i < 2:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
# the exact order in which I have arranged them.
|
||||
mapping = {k: k for k in unet_state_dict.keys()}
|
||||
for sd_name, hf_name in unet_conversion_map:
|
||||
mapping[hf_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in unet_conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
vae_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("nin_shortcut", "conv_shortcut"),
|
||||
("norm_out", "conv_norm_out"),
|
||||
("mid.attn_1.", "mid_block.attentions.0."),
|
||||
]
|
||||
|
||||
for i in range(4):
|
||||
# down_blocks have two resnets
|
||||
for j in range(2):
|
||||
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
||||
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
||||
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
||||
sd_downsample_prefix = f"down.{i}.downsample."
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
# the following are for SDXL
|
||||
("q.", "to_q."),
|
||||
("k.", "to_k."),
|
||||
("v.", "to_v."),
|
||||
("proj_out.", "to_out.0."),
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
for sd_part, hf_part in vae_conversion_map_attn:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# =========================#
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("transformer.resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_openclip_text_enc_state_dict(text_enc_dict):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
or k.endswith(".self_attn.v_proj.weight")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.weight")]
|
||||
k_code = k[-len("q_proj.weight")]
|
||||
if k_pre not in capture_qkv_weight:
|
||||
capture_qkv_weight[k_pre] = [None, None, None]
|
||||
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.bias")
|
||||
or k.endswith(".self_attn.k_proj.bias")
|
||||
or k.endswith(".self_attn.v_proj.bias")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.bias")]
|
||||
k_code = k[-len("q_proj.bias")]
|
||||
if k_pre not in capture_qkv_bias:
|
||||
capture_qkv_bias[k_pre] = [None, None, None]
|
||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
|
||||
for k_pre, tensors in capture_qkv_weight.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_openai_text_enc_state_dict(text_enc_dict):
|
||||
return text_enc_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||
parser.add_argument(
|
||||
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.model_path is not None, "Must provide a model path!"
|
||||
|
||||
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
|
||||
|
||||
# Path for safetensors
|
||||
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
|
||||
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
|
||||
|
||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = load_file(unet_path, device="cpu")
|
||||
else:
|
||||
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = load_file(vae_path, device="cpu")
|
||||
else:
|
||||
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||
else:
|
||||
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_2_path):
|
||||
text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
|
||||
else:
|
||||
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
|
||||
text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
||||
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
||||
|
||||
if args.half:
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
|
||||
if args.use_safetensors:
|
||||
save_file(state_dict, args.checkpoint_path)
|
||||
else:
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, args.checkpoint_path)
|
||||
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
@@ -233,7 +233,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.20.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.21.0.dev0"
|
||||
__version__ = "0.20.0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -133,9 +133,6 @@ else:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
|
||||
@@ -1245,7 +1245,6 @@ class LoraLoaderMixin:
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
|
||||
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
|
||||
# Convert from the old naming convention to the new naming convention.
|
||||
@@ -1284,17 +1283,10 @@ class LoraLoaderMixin:
|
||||
f"{name}.out_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
rank = text_encoder_lora_state_dict[
|
||||
"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
|
||||
].shape[1]
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -1352,7 +1344,7 @@ class LoraLoaderMixin:
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alphas=None,
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
rank=4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
):
|
||||
@@ -1373,46 +1365,38 @@ class LoraLoaderMixin:
|
||||
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
|
||||
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
|
||||
|
||||
if isinstance(rank, dict):
|
||||
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
attn_module.q_proj = PatchedLoraProjection(
|
||||
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.k_proj = PatchedLoraProjection(
|
||||
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.v_proj = PatchedLoraProjection(
|
||||
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
|
||||
|
||||
attn_module.out_proj = PatchedLoraProjection(
|
||||
attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
|
||||
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
|
||||
|
||||
mlp_module.fc1 = PatchedLoraProjection(
|
||||
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
|
||||
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
|
||||
|
||||
mlp_module.fc2 = PatchedLoraProjection(
|
||||
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
|
||||
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
|
||||
|
||||
@@ -1806,9 +1790,6 @@ class FromSingleFileMixin:
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
@@ -1839,7 +1820,6 @@ class FromSingleFileMixin:
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -1956,7 +1936,6 @@ class FromSingleFileMixin:
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
original_config_file=original_config_file,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
|
||||
@@ -137,15 +137,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self.latent_shift = latent_shift
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.spatial_scale_factor = 2**out_channels
|
||||
self.tile_overlap_factor = 0.125
|
||||
self.tile_sample_min_size = 512
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (EncoderTiny, DecoderTiny)):
|
||||
module.gradient_checkpointing = value
|
||||
@@ -158,147 +149,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
|
||||
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
|
||||
plain `tuple` is returned.
|
||||
"""
|
||||
# scale of encoder output relative to input
|
||||
sf = self.spatial_scale_factor
|
||||
tile_size = self.tile_sample_min_size
|
||||
|
||||
# number of pixels to blend and to traverse between tile
|
||||
blend_size = int(tile_size * self.tile_overlap_factor)
|
||||
traverse_size = tile_size - blend_size
|
||||
|
||||
# tiles index (up/left)
|
||||
ti = range(0, x.shape[-2], traverse_size)
|
||||
tj = range(0, x.shape[-1], traverse_size)
|
||||
|
||||
# mask for blending
|
||||
blend_masks = torch.stack(
|
||||
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
|
||||
)
|
||||
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
||||
|
||||
# output array
|
||||
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
|
||||
for i in ti:
|
||||
for j in tj:
|
||||
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
||||
# tile result
|
||||
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
|
||||
tile = self.encoder(tile_in)
|
||||
h, w = tile.shape[-2], tile.shape[-1]
|
||||
# blend tile result into output
|
||||
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
||||
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
||||
blend_mask = blend_mask_i * blend_mask_j
|
||||
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
|
||||
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
||||
return out
|
||||
|
||||
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
# scale of decoder output relative to input
|
||||
sf = self.spatial_scale_factor
|
||||
tile_size = self.tile_latent_min_size
|
||||
|
||||
# number of pixels to blend and to traverse between tiles
|
||||
blend_size = int(tile_size * self.tile_overlap_factor)
|
||||
traverse_size = tile_size - blend_size
|
||||
|
||||
# tiles index (up/left)
|
||||
ti = range(0, x.shape[-2], traverse_size)
|
||||
tj = range(0, x.shape[-1], traverse_size)
|
||||
|
||||
# mask for blending
|
||||
blend_masks = torch.stack(
|
||||
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
|
||||
)
|
||||
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
||||
|
||||
# output array
|
||||
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
|
||||
for i in ti:
|
||||
for j in tj:
|
||||
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
||||
# tile result
|
||||
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
|
||||
tile = self.decoder(tile_in)
|
||||
h, w = tile.shape[-2], tile.shape[-1]
|
||||
# blend tile result into output
|
||||
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
||||
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
||||
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
|
||||
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
||||
return out
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
|
||||
output = self.encoder(x)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
@@ -307,11 +162,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
||||
output = self.decoder(x)
|
||||
# Refer to the following discussion to know why this is needed.
|
||||
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
|
||||
output = output.mul_(2).sub_(1)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
@@ -330,15 +184,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
enc = self.encode(sample).latents
|
||||
|
||||
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
|
||||
# as if we were storing the latents in an RGBA uint8 image.
|
||||
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
|
||||
|
||||
# unquantize latents back into [0, 1], then unscale latents back to their original range,
|
||||
# as if we were loading the latents from an RGBA uint8 image.
|
||||
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
||||
|
||||
unscaled_enc = self.unscale_latents(scaled_enc)
|
||||
dec = self.decode(unscaled_enc)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -88,7 +88,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
@@ -182,7 +181,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
|
||||
@@ -732,8 +732,7 @@ class EncoderTiny(nn.Module):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
|
||||
else:
|
||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||
x = self.layers(x.add(1).div(2))
|
||||
x = self.layers(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -791,5 +790,4 @@ class DecoderTiny(nn.Module):
|
||||
else:
|
||||
x = self.layers(x)
|
||||
|
||||
# scale image from [0, 1] to [-1, 1] to match diffusers convention
|
||||
return x.mul(2).sub(1)
|
||||
return x
|
||||
|
||||
@@ -46,7 +46,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
|
||||
@@ -418,7 +418,8 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
@@ -435,9 +436,9 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.AudioPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated audio.
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
)
|
||||
else:
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .pipeline_audioldm2 import AudioLDM2Pipeline
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,973 +0,0 @@
|
||||
# Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
GPT2Model,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
SpeechT5HifiGan,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_librosa_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import scipy
|
||||
>>> import torch
|
||||
>>> from diffusers import AudioLDM2Pipeline
|
||||
|
||||
>>> repo_id = "cvssp/audioldm2"
|
||||
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # define the prompts
|
||||
>>> prompt = "The sound of a hammer hitting a wooden surface."
|
||||
>>> negative_prompt = "Low quality."
|
||||
|
||||
>>> # set the seed for generator
|
||||
>>> generator = torch.Generator("cuda").manual_seed(0)
|
||||
|
||||
>>> # run the generation
|
||||
>>> audio = pipe(
|
||||
... prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_inference_steps=200,
|
||||
... audio_length_in_s=10.0,
|
||||
... num_waveforms_per_prompt=3,
|
||||
... generator=generator,
|
||||
... ).audios
|
||||
|
||||
>>> # save the best audio sample (index 0) as a .wav file
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
inputs_embeds,
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
# only last token for inputs_embeds if past is defined in kwargs
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
|
||||
return {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
|
||||
|
||||
class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using AudioLDM2.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.ClapModel`]):
|
||||
First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
|
||||
[CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
|
||||
specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
|
||||
text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
|
||||
rank generated waveforms against the text prompt by computing similarity scores.
|
||||
text_encoder_2 ([`~transformers.T5EncoderModel`]):
|
||||
Second frozen text-encoder. AudioLDM2 uses the encoder of
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
|
||||
projection_model ([`AudioLDM2ProjectionModel`]):
|
||||
A trained model used to linearly project the hidden-states from the first and second text encoder models
|
||||
and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
|
||||
concatenated to give the input to the language model.
|
||||
language_model ([`~transformers.GPT2Model`]):
|
||||
An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
|
||||
outputs from the two text encoders.
|
||||
tokenizer ([`~transformers.RobertaTokenizer`]):
|
||||
Tokenizer to tokenize text for the first frozen text-encoder.
|
||||
tokenizer_2 ([`~transformers.T5Tokenizer`]):
|
||||
Tokenizer to tokenize text for the second frozen text-encoder.
|
||||
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
|
||||
Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded audio latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
||||
Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: ClapModel,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
projection_model: AudioLDM2ProjectionModel,
|
||||
language_model: GPT2Model,
|
||||
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
||||
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
|
||||
feature_extractor: ClapFeatureExtractor,
|
||||
unet: AudioLDM2UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
vocoder: SpeechT5HifiGan,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
projection_model=projection_model,
|
||||
language_model=language_model,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
feature_extractor=feature_extractor,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
model_sequence = [
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
self.projection_model,
|
||||
self.language_model,
|
||||
self.unet,
|
||||
self.vae,
|
||||
]
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in model_sequence:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def generate_language_model(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor = None,
|
||||
max_new_tokens: int = 8,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
|
||||
|
||||
Parameters:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
max_new_tokens (`int`):
|
||||
Number of new tokens to generate.
|
||||
model_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
|
||||
function of the model.
|
||||
|
||||
Return:
|
||||
`inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
|
||||
|
||||
# forward pass to get next hidden states
|
||||
output = self.language_model(**model_inputs, return_dict=True)
|
||||
|
||||
next_hidden_states = output.last_hidden_state
|
||||
|
||||
# Update the model input
|
||||
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
|
||||
|
||||
# Update generated hidden states, model inputs, and length for next step
|
||||
model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
|
||||
|
||||
return inputs_embeds[:, -max_new_tokens:, :]
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_waveforms_per_prompt (`int`):
|
||||
number of waveforms that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
|
||||
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
||||
be computed from `prompt` input argument.
|
||||
negative_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
|
||||
mask will be computed from `negative_prompt` input argument.
|
||||
max_new_tokens (`int`, *optional*, defaults to None):
|
||||
The number of new tokens to generate with the GPT2 language model.
|
||||
Returns:
|
||||
prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings from the Flan T5 model.
|
||||
attention_mask (`torch.LongTensor`):
|
||||
Attention mask to be applied to the `prompt_embeds`.
|
||||
generated_prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings generated from the GPT2 langauge model.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import scipy
|
||||
>>> import torch
|
||||
>>> from diffusers import AudioLDM2Pipeline
|
||||
|
||||
>>> repo_id = "cvssp/audioldm2"
|
||||
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # Get text embedding vectors
|
||||
>>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
|
||||
... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
|
||||
... device="cuda",
|
||||
... do_classifier_free_guidance=True,
|
||||
... )
|
||||
|
||||
>>> # Pass text embeddings to pipeline for text-conditional audio generation
|
||||
>>> audio = pipe(
|
||||
... prompt_embeds=prompt_embeds,
|
||||
... attention_mask=attention_mask,
|
||||
... generated_prompt_embeds=generated_prompt_embeds,
|
||||
... num_inference_steps=200,
|
||||
... audio_length_in_s=10.0,
|
||||
... ).audios[0]
|
||||
|
||||
>>> # save generated audio sample
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
|
||||
```"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds_list = []
|
||||
attention_mask_list = []
|
||||
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
f"The following part of your input was truncated because {text_encoder.config.model_type} can "
|
||||
f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_input_ids = text_input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
if text_encoder.config.model_type == "clap":
|
||||
prompt_embeds = text_encoder.get_text_features(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
prompt_embeds = prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
attention_mask = attention_mask.new_ones((batch_size, 1))
|
||||
else:
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
attention_mask_list.append(attention_mask)
|
||||
|
||||
projection_output = self.projection_model(
|
||||
hidden_states=prompt_embeds_list[0],
|
||||
hidden_states_1=prompt_embeds_list[1],
|
||||
attention_mask=attention_mask_list[0],
|
||||
attention_mask_1=attention_mask_list[1],
|
||||
)
|
||||
projected_prompt_embeds = projection_output.hidden_states
|
||||
projected_attention_mask = projection_output.attention_mask
|
||||
|
||||
generated_prompt_embeds = self.generate_language_model(
|
||||
projected_prompt_embeds,
|
||||
attention_mask=projected_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
attention_mask = (
|
||||
attention_mask.to(device=device)
|
||||
if attention_mask is not None
|
||||
else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
|
||||
)
|
||||
generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, hidden_size = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
|
||||
|
||||
# duplicate attention mask for each generation per prompt
|
||||
attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
|
||||
attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
|
||||
# duplicate generated embeddings for each generation per prompt, using mps friendly method
|
||||
generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
generated_prompt_embeds = generated_prompt_embeds.view(
|
||||
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
negative_attention_mask_list = []
|
||||
max_length = prompt_embeds.shape[1]
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length
|
||||
if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
||||
else max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_input_ids = uncond_input.input_ids.to(device)
|
||||
negative_attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
if text_encoder.config.model_type == "clap":
|
||||
negative_prompt_embeds = text_encoder.get_text_features(
|
||||
uncond_input_ids,
|
||||
attention_mask=negative_attention_mask,
|
||||
)
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
|
||||
else:
|
||||
negative_prompt_embeds = text_encoder(
|
||||
uncond_input_ids,
|
||||
attention_mask=negative_attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
negative_attention_mask_list.append(negative_attention_mask)
|
||||
|
||||
projection_output = self.projection_model(
|
||||
hidden_states=negative_prompt_embeds_list[0],
|
||||
hidden_states_1=negative_prompt_embeds_list[1],
|
||||
attention_mask=negative_attention_mask_list[0],
|
||||
attention_mask_1=negative_attention_mask_list[1],
|
||||
)
|
||||
negative_projected_prompt_embeds = projection_output.hidden_states
|
||||
negative_projected_attention_mask = projection_output.attention_mask
|
||||
|
||||
negative_generated_prompt_embeds = self.generate_language_model(
|
||||
negative_projected_prompt_embeds,
|
||||
attention_mask=negative_projected_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
negative_attention_mask = (
|
||||
negative_attention_mask.to(device=device)
|
||||
if negative_attention_mask is not None
|
||||
else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
|
||||
)
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
|
||||
dtype=self.language_model.dtype, device=device
|
||||
)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
|
||||
|
||||
# duplicate unconditional attention mask for each generation per prompt
|
||||
negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
|
||||
negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# duplicate unconditional generated embeddings for each generation per prompt
|
||||
seq_len = negative_generated_prompt_embeds.shape[1]
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
|
||||
batch_size * num_waveforms_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
attention_mask = torch.cat([negative_attention_mask, attention_mask])
|
||||
generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
|
||||
|
||||
return prompt_embeds, attention_mask, generated_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
def mel_spectrogram_to_waveform(self, mel_spectrogram):
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.squeeze(1)

    waveform = self.vocoder(mel_spectrogram)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    waveform = waveform.cpu().float()
    return waveform
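# Illustrative shape sketch (assumed HiFi-GAN-style vocoder config): a mel spectrogram of shape
# (batch, frames, n_mels) is upsampled to a waveform of shape (batch, frames * hop_length), where
# hop_length is the product of the vocoder's upsample rates.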
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
    if not is_librosa_available():
        logger.info(
            "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
            "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
            "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
        )
        return audio
    inputs = self.tokenizer(text, return_tensors="pt", padding=True)
    resampled_audio = librosa.resample(
        audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
    )
    inputs["input_features"] = self.feature_extractor(
        list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
    ).input_features.type(dtype)
    inputs = inputs.to(device)

    # compute the audio-text similarity score using the CLAP model
    logits_per_text = self.text_encoder(**inputs).logits_per_text
    # sort by the highest matching generations per prompt
    indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
    audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
    return audio
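# Illustrative usage (hypothetical values): with `num_waveforms_per_prompt=3` and a text prompt,
# `__call__` decodes three candidate waveforms and routes them through `score_waveforms`; the CLAP text
# encoder scores each candidate against the prompt and the waveforms are returned best match first, e.g.
#     audio = pipe("Techno music with a strong, upbeat tempo", num_waveforms_per_prompt=3).audios
#     best_audio = audio[0]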
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
generated_prompt_embeds=None,
|
||||
negative_generated_prompt_embeds=None,
|
||||
attention_mask=None,
|
||||
negative_attention_mask=None,
|
||||
):
|
||||
min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
|
||||
if audio_length_in_s < min_audio_length_in_s:
|
||||
raise ValueError(
|
||||
f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
|
||||
f"is {audio_length_in_s}."
|
||||
)
|
||||
|
||||
if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
|
||||
f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
|
||||
f"{self.vae_scale_factor}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
|
||||
raise ValueError(
|
||||
"Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
|
||||
"`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
|
||||
"both arguments are specified"
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
|
||||
raise ValueError(
|
||||
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
||||
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
|
||||
if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
|
||||
f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
|
||||
f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
|
||||
)
|
||||
if (
|
||||
negative_attention_mask is not None
|
||||
and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
|
||||
):
|
||||
raise ValueError(
|
||||
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
||||
f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    shape = (
        batch_size,
        num_channels_latents,
        height // self.vae_scale_factor,
        self.vocoder.config.model_in_dim // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
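# Illustrative shape sketch (assumed config values): with a VAE scale factor of 4, a spectrogram height
# rounded up to 1024 frames, and `model_in_dim=64` mel bins, the latents have shape
# (batch_size, num_channels_latents, 256, 16) before being scaled by `init_noise_sigma`.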
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
num_inference_steps: int = 200,
|
||||
guidance_scale: float = 3.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_waveforms_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
||||
audio_length_in_s (`int`, *optional*, defaults to 10.24):
|
||||
The length of the generated audio sample in seconds.
|
||||
num_inference_steps (`int`, *optional*, defaults to 200):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
||||
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
|
||||
scoring is performed between the generated outputs and the text prompt. This scoring ranks the
|
||||
generated waveforms based on their cosine similarity with the text input in the joint text-audio
|
||||
embedding space.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
|
||||
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
||||
be computed from `prompt` input argument.
|
||||
negative_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
|
||||
mask will be computed from `negative_prompt` input argument.
|
||||
max_new_tokens (`int`, *optional*, defaults to None):
|
||||
Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
|
||||
be taken from the config of the model.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
||||
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
||||
model (LDM) output.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
if audio_length_in_s is None:
|
||||
audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
|
||||
|
||||
height = int(audio_length_in_s / vocoder_upsample_factor)
|
||||
|
||||
original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
|
||||
if height % self.vae_scale_factor != 0:
|
||||
height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
|
||||
logger.info(
|
||||
f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
|
||||
f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
|
||||
f"denoising process."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds,
|
||||
attention_mask,
|
||||
negative_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
generated_prompt_embeds=generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds=negative_generated_prompt_embeds,
|
||||
attention_mask=attention_mask,
|
||||
negative_attention_mask=negative_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_waveforms_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=generated_prompt_embeds,
|
||||
encoder_hidden_states_1=prompt_embeds,
|
||||
encoder_attention_mask_1=attention_mask,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
if output_type != "latent":
    latents = 1 / self.vae.config.scaling_factor * latents
    mel_spectrogram = self.vae.decode(latents).sample
else:
    return AudioPipelineOutput(audios=latents)

audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

audio = audio[:, :original_waveform_length]

# 9. Automatic scoring
if num_waveforms_per_prompt > 1 and prompt is not None:
    audio = self.score_waveforms(
        text=prompt,
        audio=audio,
        num_waveforms_per_prompt=num_waveforms_per_prompt,
        device=device,
        dtype=prompt_embeds.dtype,
    )

if output_type == "np":
    audio = audio.numpy()

if not return_dict:
    return (audio,)

return AudioPipelineOutput(audios=audio)
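Taken end to end, the input checks, prompt encoding, denoising loop, and post-processing above correspond to the following minimal usage sketch. It assumes the `cvssp/audioldm2` checkpoint and a CUDA device, and simply mirrors the behaviour documented in the docstrings rather than adding anything new:

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

# load the pipeline in half precision and move it to the GPU
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")

# generate three candidate waveforms; because a text prompt is given and
# num_waveforms_per_prompt > 1, the best CLAP match is returned first
audio = pipe(
    prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
    negative_prompt="Low quality, average quality",
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,
    generator=torch.Generator("cuda").manual_seed(0),
).audios

# save the top-ranked waveform (a 16 kHz vocoder sampling rate is assumed here, as in the docstring example)
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```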
@@ -14,7 +14,6 @@
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -1170,76 +1169,3 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# Override to properly handle the loading and unloading of the additional text encoder.
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
)
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
|
||||
def save_lora_weights(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
state_dict = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
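# Illustrative (hypothetical module name): `pack_weights` prefixes every state-dict key, so a UNet LoRA
# parameter such as "mid_block.attentions.0.processor.to_q_lora.down.weight" is saved as
# "unet.mid_block.attentions.0.processor.to_q_lora.down.weight".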
|
||||
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
self.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -924,7 +924,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
@@ -941,7 +940,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
from_flax=from_flax,
|
||||
use_safetensors=use_safetensors,
|
||||
use_onnx=use_onnx,
|
||||
custom_pipeline=custom_pipeline,
|
||||
custom_revision=custom_revision,
|
||||
variant=variant,
|
||||
|
||||
@@ -981,9 +981,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
prompt_embeds_edit[1:2] += edit_direction
|
||||
|
||||
# 10. Second denoising loop to generate the edited image.
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
latents = latents_init
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
|
||||
@@ -96,7 +96,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -105,13 +104,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
return indices.item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -129,10 +121,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
A scaled input sample.
|
||||
"""
|
||||
# Get sigma corresponding to timestep
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_idx = self.index_for_timestep(timestep)
|
||||
sigma = self.sigmas[step_idx]
|
||||
|
||||
sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
|
||||
@@ -228,8 +220,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
def _convert_to_karras(self, ramp):
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
@@ -277,24 +267,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
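# Illustrative (hypothetical schedule values): if `self.timesteps` contains a repeated entry, e.g.
# [999, 999, 801, ...], then `index_candidates` holds two matches for timestep 999 and the second one is
# selected, so the very first `step` call resumes from the later occurrence and no sigma is skipped.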
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -346,16 +318,18 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
sigma_min = self.config.sigma_min
|
||||
sigma_max = self.config.sigma_max
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# sigma_next corresponds to next_t in original implementation
|
||||
sigma = self.sigmas[self.step_index]
|
||||
if self.step_index + 1 < self.config.num_train_timesteps:
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
sigma = self.sigmas[step_index]
|
||||
if step_index + 1 < self.config.num_train_timesteps:
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# Set sigma_next to sigma_min
|
||||
sigma_next = self.sigmas[-1]
|
||||
@@ -384,9 +358,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
# tau = sigma_hat, eps = sigma_min
|
||||
prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
|
||||
@@ -166,8 +166,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps)
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -176,13 +174,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -200,11 +191,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
@@ -223,20 +213,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
|
||||
::-1
|
||||
].copy()
|
||||
elif self.config.timestep_spacing == "leading":
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
||||
timesteps += self.config.steps_offset
|
||||
elif self.config.timestep_spacing == "trailing":
|
||||
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
|
||||
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
||||
timesteps -= 1
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -247,27 +237,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
if str(device).startswith("mps"):
|
||||
# mps does not support float64
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -321,10 +295,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.config.prediction_type == "epsilon":
|
||||
@@ -339,8 +314,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
sigma_from = self.sigmas[step_index]
|
||||
sigma_to = self.sigmas[step_index + 1]
|
||||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
|
||||
@@ -356,9 +331,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
|
||||
@@ -175,8 +175,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -185,13 +183,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(
|
||||
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||
) -> torch.FloatTensor:
|
||||
@@ -209,10 +200,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
self.is_scale_input_called = True
|
||||
@@ -232,20 +224,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
|
||||
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
|
||||
::-1
|
||||
].copy()
|
||||
elif self.config.timestep_spacing == "leading":
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
||||
timesteps += self.config.steps_offset
|
||||
elif self.config.timestep_spacing == "trailing":
|
||||
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
|
||||
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
||||
timesteps -= 1
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -271,9 +263,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
if str(device).startswith("mps"):
|
||||
# mps does not support float64
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
@@ -312,23 +306,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -388,10 +365,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
@@ -423,13 +401,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
dt = self.sigmas[step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
|
||||
@@ -149,8 +149,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
@@ -177,13 +175,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -203,10 +194,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma = self.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
return sample
|
||||
|
||||
@@ -231,18 +221,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
||||
elif self.config.timestep_spacing == "leading":
|
||||
step_ratio = num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
||||
timesteps += self.config.steps_offset
|
||||
elif self.config.timestep_spacing == "trailing":
|
||||
step_ratio = num_train_timesteps / self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
|
||||
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
||||
timesteps -= 1
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -264,15 +254,16 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps)
|
||||
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
if str(device).startswith("mps"):
|
||||
# mps does not support float64
|
||||
self.timesteps = timesteps.to(device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
# empty dt and derivative
|
||||
self.prev_derivative = None
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
@@ -319,24 +310,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def state_in_first_order(self):
|
||||
return self.dt is None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
@@ -363,21 +336,19 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
|
||||
# (YiYi notes: keep this for now since we are keeping the add_noise method)
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
self._index_counter[timestep_int] += 1
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
else:
|
||||
# 2nd order / Heun's method
|
||||
sigma = self.sigmas[self.step_index - 1]
|
||||
sigma_next = self.sigmas[self.step_index]
|
||||
sigma = self.sigmas[step_index - 1]
|
||||
sigma_next = self.sigmas[step_index]
|
||||
|
||||
# currently only gamma=0 is supported. This usually works best anyways.
|
||||
# We can support gamma in the future but then need to scale the timestep before
|
||||
@@ -433,9 +404,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
|
||||
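# Illustrative sketch (not part of the diff above): the two indexing strategies these
# scheduler hunks toggle between. `index_for_timestep` derives the sigma index from the
# timestep value on every call, while the `_step_index` counter is initialised on the
# first `step` and advanced by one afterwards. All names below are hypothetical stand-ins.
import torch

timesteps = torch.tensor([999.0, 999.0, 501.0, 501.0, 3.0])

def index_for_timestep(timestep, schedule_timesteps=timesteps):
    # stateless lookup: recomputed from the timestep value on every call
    return (schedule_timesteps == timestep).nonzero()[0].item()

class StepIndexCounter:
    # stateful variant: pick the second match once (robust for image-to-image style
    # schedules with repeated timesteps), then simply increment per step
    def __init__(self, schedule_timesteps=timesteps):
        self.timesteps = schedule_timesteps
        self._step_index = None

    def init_step_index(self, timestep):
        candidates = (self.timesteps == timestep).nonzero()
        self._step_index = (candidates[1] if len(candidates) > 1 else candidates[0]).item()

    def step(self, timestep):
        if self._step_index is None:
            self.init_step_index(timestep)
        index = self._step_index
        self._step_index += 1  # advance after each scheduler step
        return index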
@@ -55,14 +55,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

# running values
self.ets = []
self._step_index = None

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
@@ -89,25 +81,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = timesteps.to(device)

self.ets = []
self._step_index = None

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]

self._step_index = step_index.item()

def step(
self,
@@ -139,11 +112,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)

timestep_index = self.step_index
prev_timestep_index = self.step_index + 1
timestep_index = (self.timesteps == timestep).nonzero().item()
prev_timestep_index = timestep_index + 1

ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
self.ets.append(ets)
@@ -159,9 +130,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample,)

@@ -137,7 +137,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None

# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -166,13 +165,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -192,13 +184,12 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
step_index = self.index_for_timestep(timestep)

if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma = self.sigmas[step_index]
else:
sigma = self.sigmas_interpol[self.step_index - 1]
sigma = self.sigmas_interpol[step_index - 1]

sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
@@ -224,18 +215,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
@@ -268,7 +259,12 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

timesteps = torch.from_numpy(timesteps).to(device)
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)

timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

@@ -280,8 +276,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# we need an index counter
self._index_counter = defaultdict(int)

self._step_index = None

def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
@@ -309,24 +303,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.sample is None

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]

self._step_index = step_index.item()

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
@@ -356,24 +332,23 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
step_index = self.index_for_timestep(timestep)

# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1

if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index]
sigma_up = self.sigmas_up[self.step_index]
sigma_down = self.sigmas_down[self.step_index - 1]
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index]
sigma_up = self.sigmas_up[step_index]
sigma_down = self.sigmas_down[step_index - 1]
else:
# 2nd order / KPDM2's method
sigma = self.sigmas[self.step_index - 1]
sigma_interpol = self.sigmas_interpol[self.step_index - 1]
sigma_up = self.sigmas_up[self.step_index - 1]
sigma_down = self.sigmas_down[self.step_index - 1]
sigma = self.sigmas[step_index - 1]
sigma_interpol = self.sigmas_interpol[step_index - 1]
sigma_up = self.sigmas_up[step_index - 1]
sigma_down = self.sigmas_down[step_index - 1]

# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
@@ -423,9 +398,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
prev_sample = prev_sample + noise * sigma_up

# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample,)

@@ -137,8 +137,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

self._step_index = None

# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
@@ -166,13 +164,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):

return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -192,13 +183,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
step_index = self.index_for_timestep(timestep)

if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma = self.sigmas[step_index]
else:
sigma = self.sigmas_interpol[self.step_index]
sigma = self.sigmas_interpol[step_index]

sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
@@ -224,18 +214,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
@@ -257,7 +247,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
[sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
)

timesteps = torch.from_numpy(timesteps).to(device)
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)

# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
@@ -271,8 +265,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# we need an index counter
self._index_counter = defaultdict(int)

self._step_index = None

def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
@@ -300,24 +292,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.sample is None

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]

self._step_index = step_index.item()

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
@@ -344,22 +318,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.step_index is None:
self._init_step_index(timestep)
step_index = self.index_for_timestep(timestep)

# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1

if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index + 1]
sigma_next = self.sigmas[self.step_index + 1]
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index + 1]
sigma_next = self.sigmas[step_index + 1]
else:
# 2nd order / KDPM2's method
sigma = self.sigmas[self.step_index - 1]
sigma_interpol = self.sigmas_interpol[self.step_index]
sigma_next = self.sigmas[self.step_index]
sigma = self.sigmas[step_index - 1]
sigma_interpol = self.sigmas_interpol[step_index]
sigma_next = self.sigmas[step_index]

# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
@@ -402,9 +375,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = self.sample
self.sample = None

# upon completion increase step index by one
self._step_index += 1

prev_sample = sample + derivative * dt

if not return_dict:

@@ -169,8 +169,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.derivatives = []
self.is_scale_input_called = False

self._step_index = None

@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
@@ -179,13 +177,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index

def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -203,11 +194,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""

if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas[self.step_index]
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
@@ -248,20 +238,20 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
@@ -279,29 +269,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)

self.derivatives = []

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]

self._step_index = step_index.item()

# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
@@ -376,10 +351,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"See `StableDiffusionPipeline` for a usage example."
)

if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas[self.step_index]
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
@@ -401,17 +376,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.derivatives.pop(0)

# 3. Compute linear multistep coefficients
order = min(self.step_index + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]
order = min(step_index + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)

# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample,)

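# Illustrative sketch (not part of the diff above): how the three `timestep_spacing`
# modes used throughout these schedulers lay out timesteps for a toy schedule of
# 10 training steps and 5 inference steps. Values below are only for illustration.
import numpy as np

num_train_timesteps, num_inference_steps = 10, 5

# "linspace": evenly spaced floats over [0, num_train_timesteps - 1], reversed
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

# "leading": integer multiples of the step ratio, reversed, then shifted by steps_offset
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

# "trailing": counted down from num_train_timesteps, then shifted down by one
step_ratio = num_train_timesteps / num_inference_steps
trailing = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) - 1

print(linspace)  # [9.   6.75 4.5  2.25 0.  ]
print(leading)   # [8. 6. 4. 2. 0.]
print(trailing)  # [9. 7. 5. 3. 1.]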
@@ -32,51 +32,6 @@ class AltDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])


class AudioLDM2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AudioLDM2ProjectionModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AudioLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

@@ -567,7 +567,7 @@ class DummyObject(type):
"""

def __getattr__(cls, key):
if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
if key.startswith("_") and key != "_load_connected_pipes":
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)

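# Illustrative sketch (not part of the diff above): a minimal, simplified re-creation of
# the DummyObject/requires_backends pattern that the placeholder AudioLDM2 classes above
# rely on. The helper below is a stand-in for illustration only; the real implementation
# lives in diffusers.utils and additionally checks which backends are importable.
class _RequiresBackendsError(ImportError):
    pass

def requires_backends(obj, backends):
    # always raise here, standing in for "required backend is not installed"
    name = obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
    raise _RequiresBackendsError(f"{name} requires the following backends: {backends}")

class DummyObject(type):
    # metaclass: touching the placeholder class without the backend installed raises
    def __getattr__(cls, key):
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            raise AttributeError(key)
        requires_backends(cls, cls._backends)

class AudioLDM2Pipeline(metaclass=DummyObject):  # simplified stand-in for the dummy class
    _backends = ["torch", "transformers"]

# Accessing AudioLDM2Pipeline.from_pretrained now raises _RequiresBackendsError,
# listing ["torch", "transformers"] as the missing backends.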
@@ -44,13 +44,13 @@ if is_torch_available():

if "DIFFUSERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
try:
# try creating device to see if provided device is valid
_ = torch.device(torch_device)
except RuntimeError as e:
raise RuntimeError(
f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}"
) from e

available_backends = ["cuda", "cpu", "mps"]
if torch_device not in available_backends:
raise ValueError(
f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:"
f" {available_backends}"
)
logger.info(f"torch_device overrode to {torch_device}")
else:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

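# Illustrative sketch (not part of the diff above): how the DIFFUSERS_TEST_DEVICE
# override shown in this hunk is typically exercised. One side of the diff validates
# the value against "cuda", "cpu" and "mps" before the suite picks up `torch_device`.
import os

os.environ["DIFFUSERS_TEST_DEVICE"] = "cpu"  # e.g. force CPU-only test runs
# importing the test utilities afterwards sets torch_device = "cpu" instead of auto-detecting CUDA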
@@ -221,7 +221,7 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]

max_diff = (image - new_image).abs().max().item()
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

def test_getattr_is_correct(self):
@@ -351,7 +351,7 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]

max_diff = (image - new_image).abs().max().item()
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

@require_torch_2

@@ -137,7 +137,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_accelerate.config.in_channels,
model_accelerate.config.sample_size,
model_accelerate.config.sample_size,
generator=torch.Generator("cpu").manual_seed(0),
generator=torch.manual_seed(0),
)
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
@@ -263,7 +263,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):

output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
expected_output_slice = torch.tensor([-4842.8691, -6499.6631, -3800.1953, -7978.2686, -10980.7129, -20028.8535, 8148.2822, 2342.2905, 567.7608])
# fmt: on

self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

@@ -726,8 +726,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model.disable_xformers_memory_efficient_attention()
off_sample = model(**inputs_dict).sample

assert (sample - on_sample).abs().max() <= 5e-4
assert (sample - off_sample).abs().max() <= 5e-4
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4

def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing

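# Illustrative sketch (not part of the diff above): the two tolerance styles these test
# hunks switch between. `max()` bounds the worst-case element-wise error and is
# independent of tensor size, while `sum()` accumulates the error over all elements.
import torch

image = torch.zeros(4, 4)
new_image = torch.full((4, 4), 1e-6)

max_diff = (image - new_image).abs().max().item()  # 1e-6, independent of shape
sum_diff = (image - new_image).abs().sum().item()  # 1.6e-5, grows with element count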
@@ -1,571 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapAudioConfig,
|
||||
ClapConfig,
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
ClapTextConfig,
|
||||
GPT2Config,
|
||||
GPT2Model,
|
||||
RobertaTokenizer,
|
||||
SpeechT5HifiGan,
|
||||
SpeechT5HifiGanConfig,
|
||||
T5Config,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import is_xformers_available, slow, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = AudioLDM2Pipeline
|
||||
params = TEXT_TO_AUDIO_PARAMS
|
||||
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_waveforms_per_prompt",
|
||||
"generator",
|
||||
"latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
"callback",
|
||||
"callback_steps",
|
||||
]
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = AudioLDM2UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=([None, 16, 32], [None, 16, 32]),
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_branch_config = ClapTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=16,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
projection_dim=16,
|
||||
)
|
||||
audio_branch_config = ClapAudioConfig(
|
||||
spec_size=64,
|
||||
window_size=4,
|
||||
num_mel_bins=64,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
depths=[2, 2],
|
||||
num_attention_heads=[2, 2],
|
||||
num_hidden_layers=2,
|
||||
hidden_size=192,
|
||||
projection_dim=16,
|
||||
patch_size=2,
|
||||
patch_stride=2,
|
||||
patch_embed_input_channels=4,
|
||||
)
|
||||
text_encoder_config = ClapConfig.from_text_audio_configs(
|
||||
text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
|
||||
)
|
||||
text_encoder = ClapModel(text_encoder_config)
|
||||
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
|
||||
feature_extractor = ClapFeatureExtractor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-ClapModel", hop_length=7900
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2_config = T5Config(
|
||||
vocab_size=32100,
|
||||
d_model=32,
|
||||
d_ff=37,
|
||||
d_kv=8,
|
||||
num_heads=2,
|
||||
num_layers=2,
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel(text_encoder_2_config)
|
||||
tokenizer_2 = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-T5Model", model_max_length=77)
|
||||
|
||||
torch.manual_seed(0)
|
||||
language_model_config = GPT2Config(
|
||||
n_embd=16,
|
||||
n_head=2,
|
||||
n_layer=2,
|
||||
vocab_size=1000,
|
||||
n_ctx=99,
|
||||
n_positions=99,
|
||||
)
|
||||
language_model = GPT2Model(language_model_config)
|
||||
language_model.config.max_new_tokens = 8
|
||||
|
||||
torch.manual_seed(0)
|
||||
projection_model = AudioLDM2ProjectionModel(text_encoder_dim=16, text_encoder_1_dim=32, langauge_model_dim=16)
|
||||
|
||||
vocoder_config = SpeechT5HifiGanConfig(
|
||||
model_in_dim=8,
|
||||
sampling_rate=16000,
|
||||
upsample_initial_channel=16,
|
||||
upsample_rates=[2, 2],
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
resblock_kernel_sizes=[3, 7],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
|
||||
normalize_before=False,
|
||||
)
|
||||
|
||||
vocoder = SpeechT5HifiGan(vocoder_config)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"feature_extractor": feature_extractor,
|
||||
"language_model": language_model,
|
||||
"projection_model": projection_model,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A hammer hitting a wooden surface",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_audioldm2_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = audioldm_pipe(**inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 256
|
||||
|
||||
audio_slice = audio[:10]
|
||||
expected_slice = np.array(
|
||||
[0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
|
||||
)
|
||||
|
||||
assert np.abs(audio_slice - expected_slice).max() < 1e-4
|
||||
|
||||
def test_audioldm2_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt"] = 3 * [inputs["prompt"]]
|
||||
|
||||
# forward
|
||||
output = audioldm_pipe(**inputs)
|
||||
audio_1 = output.audios[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = 3 * [inputs.pop("prompt")]
|
||||
|
||||
text_inputs = audioldm_pipe.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=audioldm_pipe.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
|
||||
clap_prompt_embeds = clap_prompt_embeds[:, None, :]
|
||||
|
||||
text_inputs = audioldm_pipe.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
t5_prompt_embeds = audioldm_pipe.text_encoder_2(
|
||||
text_inputs,
|
||||
)
|
||||
t5_prompt_embeds = t5_prompt_embeds[0]
|
||||
|
||||
projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
|
||||
generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)
|
||||
|
||||
inputs["prompt_embeds"] = t5_prompt_embeds
|
||||
inputs["generated_prompt_embeds"] = generated_prompt_embeds
|
||||
|
||||
# forward
|
||||
output = audioldm_pipe(**inputs)
|
||||
audio_2 = output.audios[0]
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
def test_audioldm2_negative_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
negative_prompt = 3 * ["this is a negative prompt"]
|
||||
inputs["negative_prompt"] = negative_prompt
|
||||
inputs["prompt"] = 3 * [inputs["prompt"]]
|
||||
|
||||
# forward
|
||||
output = audioldm_pipe(**inputs)
|
||||
audio_1 = output.audios[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = 3 * [inputs.pop("prompt")]
|
||||
|
||||
embeds = []
|
||||
generated_embeds = []
|
||||
for p in [prompt, negative_prompt]:
|
||||
text_inputs = audioldm_pipe.tokenizer(
|
||||
p,
|
||||
padding="max_length",
|
||||
max_length=audioldm_pipe.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
|
||||
clap_prompt_embeds = clap_prompt_embeds[:, None, :]
|
||||
|
||||
text_inputs = audioldm_pipe.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=True if len(embeds) == 0 else embeds[0].shape[1],
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
t5_prompt_embeds = audioldm_pipe.text_encoder_2(
|
||||
text_inputs,
|
||||
)
|
||||
t5_prompt_embeds = t5_prompt_embeds[0]
|
||||
|
||||
projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
|
||||
generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)
|
||||
|
||||
embeds.append(t5_prompt_embeds)
|
||||
generated_embeds.append(generated_prompt_embeds)
|
||||
|
||||
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
|
||||
inputs["generated_prompt_embeds"], inputs["negative_generated_prompt_embeds"] = generated_embeds
|
||||
|
||||
# forward
|
||||
output = audioldm_pipe(**inputs)
|
||||
audio_2 = output.audios[0]
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
def test_audioldm2_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
negative_prompt = "egg cracking"
|
||||
output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 256
|
||||
|
||||
audio_slice = audio[:10]
|
||||
expected_slice = np.array(
|
||||
[0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
|
||||
)
|
||||
|
||||
assert np.abs(audio_slice - expected_slice).max() < 1e-4
|
||||
|
||||
def test_audioldm2_num_waveforms_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A hammer hitting a wooden surface"
|
||||
|
||||
# test num_waveforms_per_prompt=1 (default)
|
||||
audios = audioldm_pipe(prompt, num_inference_steps=2).audios
|
||||
|
||||
assert audios.shape == (1, 256)
|
||||
|
||||
# test num_waveforms_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
|
||||
|
||||
assert audios.shape == (batch_size, 256)
|
||||
|
||||
# test num_waveforms_per_prompt for single prompt
|
||||
num_waveforms_per_prompt = 2
|
||||
audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
|
||||
|
||||
assert audios.shape == (num_waveforms_per_prompt, 256)
|
||||
|
||||
# test num_waveforms_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
audios = audioldm_pipe(
|
||||
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
|
||||
).audios
|
||||
|
||||
assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
|
||||
|
||||
def test_audioldm2_audio_length_in_s(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) / vocoder_sampling_rate == 0.016
|
||||
|
||||
output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) / vocoder_sampling_rate == 0.032
|
||||
|
||||
def test_audioldm2_vocoder_model_in_dim(self):
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = ["hey"]
|
||||
|
||||
output = audioldm_pipe(prompt, num_inference_steps=1)
|
||||
audio_shape = output.audios.shape
|
||||
assert audio_shape == (1, 256)
|
||||
|
||||
config = audioldm_pipe.vocoder.config
|
||||
config.model_in_dim *= 2
|
||||
audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
|
||||
output = audioldm_pipe(prompt, num_inference_steps=1)
|
||||
audio_shape = output.audios.shape
|
||||
# waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
|
||||
assert audio_shape == (1, 256)
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
super().test_dict_tuple_outputs_equivalent(expected_max_difference=2e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False, expected_max_diff=2e-4)
|
||||
|
||||
def test_save_load_local(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
super().test_save_load_local(expected_max_difference=2e-4)
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
|
||||
super().test_save_load_optional_components(expected_max_difference=2e-4)
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# The method component.dtype returns the dtype of the first parameter registered in the model, not the
|
||||
# dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
|
||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||
self.assertTrue(model_dtypes["text_encoder"] == torch.float64)
|
||||
|
||||
# Without the logit scale parameters, everything is float32
|
||||
model_dtypes.pop("text_encoder")
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||
|
||||
# the CLAP sub-models are float32
|
||||
model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||
|
||||
# Once we send to fp16, all params are in half-precision, including the logit scale
|
||||
pipe.to(torch_dtype=torch.float16)
|
||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
||||
|
||||
|
||||
@slow
|
||||
class AudioLDM2PipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
|
||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||
inputs = {
|
||||
"prompt": "A hammer hitting a wooden surface",
|
||||
"latents": latents,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 2.5,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_audioldm2(self):
|
||||
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 25
|
||||
audio = audioldm_pipe(**inputs).audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 81952
|
||||
|
||||
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
|
||||
audio_slice = audio[17275:17285]
|
||||
expected_slice = np.array([0.0791, 0.0666, 0.1158, 0.1227, 0.1171, -0.2880, -0.1940, -0.0283, -0.0126, 0.1127])
|
||||
max_diff = np.abs(expected_slice - audio_slice).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_audioldm2_lms(self):
|
||||
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
|
||||
audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
audio = audioldm_pipe(**inputs).audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 81952
|
||||
|
||||
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
|
||||
audio_slice = audio[31390:31400]
|
||||
expected_slice = np.array(
|
||||
[-0.1318, -0.0577, 0.0446, -0.0573, 0.0659, 0.1074, -0.2600, 0.0080, -0.2190, -0.4301]
|
||||
)
|
||||
max_diff = np.abs(expected_slice - audio_slice).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_audioldm2_large(self):
|
||||
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2-large")
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
audio = audioldm_pipe(**inputs).audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 81952
|
||||
|
||||
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
|
||||
audio_slice = audio[8825:8835]
|
||||
expected_slice = np.array(
|
||||
[-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
|
||||
)
|
||||
max_diff = np.abs(expected_slice - audio_slice).max()
|
||||
assert max_diff < 1e-3
|
||||
@@ -1008,4 +1008,4 @@ class StableDiffusionMultiControlNetPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -455,4 +455,4 @@ class ControlNetImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -602,4 +602,4 @@ class ControlNetInpaintPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -76,7 +76,6 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
require_compel,
|
||||
require_flax,
|
||||
require_onnxruntime,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
)
|
||||
@@ -122,7 +121,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
@@ -326,22 +325,6 @@ class DownloadTests(unittest.TestCase):
|
||||
assert not any("openvino_" in f for f in files)
|
||||
|
||||
def test_download_no_onnx_by_default(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-stable-diffusion-xl-pipe",
|
||||
cache_dir=tmpdirname,
|
||||
use_safetensors=False,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# make sure that by default no onnx weights are downloaded for non-ONNX pipelines
|
||||
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
|
||||
|
||||
@require_onnxruntime
|
||||
def test_download_onnx_by_default_for_onnx_pipelines(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
|
||||
@@ -351,7 +334,21 @@ class DownloadTests(unittest.TestCase):
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# make sure that by default onnx weights are downloaded for ONNX pipelines
|
||||
# make sure that by default no onnx weights are downloaded
|
||||
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tmpdirname = DiffusionPipeline.download(
|
||||
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
|
||||
cache_dir=tmpdirname,
|
||||
use_onnx=True,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# if `use_onnx` is specified make sure weights are downloaded
|
||||
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
|
||||
assert any((f.endswith(".onnx")) for f in files)
|
||||
assert any((f.endswith(".pb")) for f in files)
|
||||
@@ -1543,7 +1540,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    @require_torch_2
    def test_from_save_pretrained_dynamo(self):
@@ -1568,7 +1565,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_from_pretrained_hub_pass_model(self):
        model_path = "google/ddpm-cifar10-32"
@@ -1591,7 +1588,7 @@ class PipelineSlowTests(unittest.TestCase):
        generator = torch.Generator(device=torch_device).manual_seed(0)
        new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

    def test_output_format(self):
        model_path = "google/ddpm-cifar10-32"
@@ -1625,7 +1622,7 @@ class PipelineSlowTests(unittest.TestCase):
        from diffusers import FlaxStableDiffusionPipeline

        with tempfile.TemporaryDirectory() as tmpdirname:
            pipe_pt.save_pretrained(tmpdirname, safe_serialization=False)
            pipe_pt.save_pretrained(tmpdirname)

            pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
                tmpdirname, safety_checker=None, from_pt=True
@@ -76,7 +76,7 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
            output_type="numpy",
        ).images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_dual_guided(self):
        pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
@@ -77,7 +77,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
            output_type="numpy",
        ).images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_dual_guided_then_text_to_image(self):
        pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
@@ -64,7 +64,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
            prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
        ).images

        assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
        assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"

    def test_inference_text2img(self):
        pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
@@ -104,8 +104,6 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        scheduler._step_index = None

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample
@@ -485,8 +485,8 @@ class SchedulerCommonTest(unittest.TestCase):

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        timestep_0 = 1
        timestep_1 = 0
        timestep_0 = 0
        timestep_1 = 1

        for scheduler_class in self.scheduler_classes:
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):