Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00

Compare commits: 18 commits (`version-ch`...`pia`)
| SHA1 |
|---|
| facdd91be9 |
| d01a61bc81 |
| 2aba6316cd |
| 911b6825c9 |
| 7542799344 |
| 8961802a37 |
| e4535a52b2 |
| b3e4161e57 |
| 4c9244dd3b |
| a93a5ba95c |
| 1c7b31a580 |
| 8d1794db96 |
| 1ba867ddea |
| 845d7f722d |
| 5025fc3594 |
| 7be981cc8c |
| 1169370f19 |
| b14881e38b |
```diff
@@ -300,6 +300,8 @@
       title: MusicLDM
     - local: api/pipelines/paint_by_example
       title: Paint by Example
+    - local: api/pipelines/pia
+      title: Personalized Image Animator (PIA)
     - local: api/pipelines/pixart
       title: PixArt-α
     - local: api/pipelines/self_attention_guidance
```
New file: docs/source/en/api/pipelines/pia.md (167 lines)

<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Image-to-Video Generation with PIA (Personalized Image Animator)

## Overview

[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://arxiv.org/abs/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen

Recent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance.

[Project page](https://pi-animator.github.io/)

## Available Pipelines

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [PIAPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pia/pipeline_pia.py) | *Image-to-Video Generation with PIA* | |

## Available checkpoints

Motion Adapter checkpoints for PIA can be found under the [OpenMMLab org](https://huggingface.co/openmmlab/PIA-condition-adapter). These checkpoints are meant to work with any model based on Stable Diffusion 1.5.

## Usage example

PIA works with a MotionAdapter checkpoint and a Stable Diffusion 1.5 model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the ResNet and Attention blocks in the Stable Diffusion UNet. In addition to the motion modules, PIA also replaces the input convolution layer of the SD 1.5 UNet model with a 9-channel input convolution layer.

The following example demonstrates how to use PIA to generate a video from a single image.

```python
import torch
from diffusers import (
    EulerDiscreteScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)

pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a field"
negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality"

generator = torch.Generator("cpu").manual_seed(0)
output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-animation.gif")
```
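
To see the effect of the 9-channel input convolution described above, you can inspect the UNet once the pipeline is created. This is a minimal sketch, assuming `pipe` from the example above is already loaded:

```python
# Sketch only: inspect the converted UNet from the example above.
# PIAPipeline wraps the SD 1.5 UNet in a UNetMotionModel and swaps in the
# 9-channel conv_in provided by the PIA motion adapter.
print(type(pipe.unet).__name__)        # UNetMotionModel
print(pipe.unet.conv_in.in_channels)   # 9 (a plain SD 1.5 UNet has 4)
```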

Here are some sample outputs:

<table>
    <tr>
        <td><center>
        masterpiece, bestquality, sunset.
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-default-output.gif"
            alt="cat in a field"
            style="width: 300px;" />
        </center></td>
    </tr>
</table>

<Tip>

If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.

</Tip>
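
As a hedged sketch of the tip above, the relevant settings can be overridden when building the scheduler from the pipeline config (shown here with `DDIMScheduler`; adjust to the scheduler you actually use):

```python
from diffusers import DDIMScheduler

# Override the config values called out in the tip above.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    clip_sample=False,        # avoid clipping samples
    beta_schedule="linear",   # PIA checkpoints can be sensitive to this
)
```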

## Using FreeInit

[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.

FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video diffusion models without any additional training. It can be applied to PIA, AnimateDiff, ModelScope, VideoCrafter, and various other video generation models seamlessly at inference time, and works by iteratively refining the latent initialization noise. More details can be found in the paper.

The following example demonstrates the usage of FreeInit.

```python
import torch
from diffusers import (
    DDIMScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter)

# enable FreeInit
# Refer to the enable_free_init documentation for a full list of configurable parameters
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)

# Memory saving options
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a hat"
negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality"

generator = torch.Generator("cpu").manual_seed(0)

output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-freeinit-animation.gif")
```

<table>
    <tr>
        <td><center>
        masterpiece, bestquality, sunset.
        <br>
        <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-freeinit-output-cat.gif"
            alt="cat in a field"
            style="width: 300px;" />
        </center></td>
    </tr>
</table>

<Tip warning={true}>

FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).

</Tip>
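
To make the trade-off concrete, the number of FreeInit refinement rounds can be set explicitly; the values below are illustrative, not recommendations:

```python
# Illustrative only: each FreeInit iteration re-runs denoising, so total
# sampling cost grows roughly linearly with num_iters.
pipe.enable_free_init(
    num_iters=3,              # more rounds -> better consistency, more compute
    use_fast_sampling=False,  # True speeds things up at some quality cost
    method="butterworth",
)

# ...generate as usual, then restore single-pass sampling:
pipe.disable_free_init()
```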

## PIAPipeline

[[autodoc]] PIAPipeline
	- all
	- __call__
	- enable_freeu
	- disable_freeu
	- enable_free_init
	- disable_free_init
	- enable_vae_slicing
	- disable_vae_slicing
	- enable_vae_tiling
	- disable_vae_tiling

## PIAPipelineOutput

[[autodoc]] pipelines.pia.PIAPipelineOutput
```diff
@@ -247,6 +247,7 @@ else:
         "LDMTextToImagePipeline",
         "MusicLDMPipeline",
         "PaintByExamplePipeline",
+        "PIAPipeline",
         "PixArtAlphaPipeline",
         "SemanticStableDiffusionPipeline",
         "ShapEImg2ImgPipeline",
@@ -606,6 +607,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LDMTextToImagePipeline,
         MusicLDMPipeline,
         PaintByExamplePipeline,
+        PIAPipeline,
         PixArtAlphaPipeline,
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
```
```diff
@@ -89,6 +89,7 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         motion_norm_num_groups: int = 32,
         motion_max_seq_length: int = 32,
         use_motion_mid_block: bool = True,
+        conv_in_channels: Optional[int] = None,
     ):
         """Container to store AnimateDiff Motion Modules
 
@@ -113,6 +114,12 @@ class MotionAdapter(ModelMixin, ConfigMixin):
         down_blocks = []
         up_blocks = []
 
+        if conv_in_channels:
+            # input
+            self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
+        else:
+            self.conv_in = None
+
         for i, channel in enumerate(block_out_channels):
             output_channel = block_out_channels[i]
             down_blocks.append(
@@ -410,6 +417,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
             config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
 
+            # For PIA UNets we need to set the number input channels to 9
+            if motion_adapter.config["conv_in_channels"]:
+                config["in_channels"] = motion_adapter.config["conv_in_channels"]
+
         # Need this for backwards compatibility with UNet2DConditionModel checkpoints
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]
@@ -419,7 +430,17 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not load_weights:
             return model
 
-        model.conv_in.load_state_dict(unet.conv_in.state_dict())
+        # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight
+        # while the last 5 channels must be PIA conv_in weights.
+        if has_motion_adapter and motion_adapter.config["conv_in_channels"]:
+            model.conv_in = motion_adapter.conv_in
+            updated_conv_in_weight = torch.cat(
+                [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1
+            )
+            model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias})
+        else:
+            model.conv_in.load_state_dict(unet.conv_in.state_dict())
 
         model.time_proj.load_state_dict(unet.time_proj.state_dict())
         model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
```
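The `conv_in` weight merge in the hunk above can be illustrated with plain tensors. This is a shape-only sketch (the channel counts assume an SD 1.5 UNet and a 9-channel PIA adapter), not library code:

```python
import torch

# SD 1.5 conv_in weight: (320, 4, 3, 3); PIA adapter conv_in weight: (320, 9, 3, 3).
sd_conv_in_weight = torch.randn(320, 4, 3, 3)
pia_conv_in_weight = torch.randn(320, 9, 3, 3)

# Keep the SD weights for the first 4 latent channels and take the PIA weights
# for the remaining 5 conditioning channels, concatenating along the input-channel dim.
merged = torch.cat([sd_conv_in_weight, pia_conv_in_weight[:, 4:, :, :]], dim=1)
assert merged.shape == (320, 9, 3, 3)
```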
```diff
@@ -168,6 +168,7 @@ else:
     _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
     _import_structure["musicldm"] = ["MusicLDMPipeline"]
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
+    _import_structure["pia"] = ["PIAPipeline"]
     _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
     _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
     _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -412,6 +413,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .latent_diffusion import LDMTextToImagePipeline
     from .musicldm import MusicLDMPipeline
     from .paint_by_example import PaintByExamplePipeline
+    from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
     from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
```
New file: src/diffusers/pipelines/pia/__init__.py (46 lines)

```python
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_pia import PIAPipeline, PIAPipelineOutput

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
```
New file: src/diffusers/pipelines/pia/pipeline_pia.py (1193 lines); diff suppressed because the file is too large.
```diff
@@ -647,6 +647,21 @@ class PaintByExamplePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class PIAPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class PixArtAlphaPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
```
New file: tests/pipelines/pia/__init__.py (empty)

New file: tests/pipelines/pia/test_pia.py (313 lines)
```python
import random
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    MotionAdapter,
    PIAPipeline,
    UNet2DConditionModel,
    UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_pipelines_common import PipelineTesterMixin


def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    return tensor


class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = PIAPipeline
    params = frozenset(
        [
            "prompt",
            "height",
            "width",
            "guidance_scale",
            "negative_prompt",
            "prompt_embeds",
            "negative_prompt_embeds",
            "cross_attention_kwargs",
        ]
    )
    batch_params = frozenset(["prompt", "image", "generator"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=2,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="linear",
            clip_sample=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        motion_adapter = MotionAdapter(
            block_out_channels=(32, 64),
            motion_layers_per_block=2,
            motion_norm_num_groups=2,
            motion_num_attention_heads=4,
            conv_in_channels=9,
        )

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "motion_adapter": motion_adapter,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        inputs = {
            "image": image,
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "pt",
        }
        return inputs

    def test_motion_unet_loading(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)

        assert isinstance(pipe.unet, UNetMotionModel)

    @unittest.skip("Attention slicing is not enabled in this pipeline")
    def test_attention_slicing_forward_pass(self):
        pass

    def test_inference_batch_single_identical(
        self,
        batch_size=2,
        expected_max_diff=1e-4,
        additional_params_copy_to_batched_inputs=["num_inference_steps"],
    ):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for components in pipe.components.values():
            if hasattr(components, "set_default_attn_processor"):
                components.set_default_attn_processor()

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        # Reset generator in case it has been used in self.get_dummy_inputs
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # batchify inputs
        batched_inputs = {}
        batched_inputs.update(inputs)

        for name in self.batch_params:
            if name not in inputs:
                continue

            value = inputs[name]
            if name == "prompt":
                len_prompt = len(value)
                batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
                batched_inputs[name][-1] = 100 * "very long"

            else:
                batched_inputs[name] = batch_size * [value]

        if "generator" in inputs:
            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]

        if "batch_size" in inputs:
            batched_inputs["batch_size"] = batch_size

        for arg in additional_params_copy_to_batched_inputs:
            batched_inputs[arg] = inputs[arg]

        output = pipe(**inputs)
        output_batch = pipe(**batched_inputs)

        assert output_batch[0].shape[0] == batch_size

        max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
        assert max_diff < expected_max_diff

    @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
    def test_to_device(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        pipe.to("cpu")
        # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cpu" for device in model_devices))

        output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
        self.assertTrue(np.isnan(output_cpu).sum() == 0)

        pipe.to("cuda")
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cuda" for device in model_devices))

        output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
        self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(torch_dtype=torch.float16)
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    def test_prompt_embeds(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        inputs.pop("prompt")
        inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
        pipe(**inputs)

    def test_free_init(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        inputs_normal = self.get_dummy_inputs(torch_device)
        frames_normal = pipe(**inputs_normal).frames[0]

        free_init_generator = torch.Generator(device=torch_device).manual_seed(0)
        pipe.enable_free_init(
            num_iters=2,
            use_fast_sampling=True,
            method="butterworth",
            order=4,
            spatial_stop_frequency=0.25,
            temporal_stop_frequency=0.25,
            generator=free_init_generator,
        )
        inputs_enable_free_init = self.get_dummy_inputs(torch_device)
        frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]

        pipe.disable_free_init()
        inputs_disable_free_init = self.get_dummy_inputs(torch_device)
        frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]

        sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
        max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
        self.assertGreater(
            sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
        )
        self.assertLess(
            max_diff_disabled,
            1e-4,
            "Disabling of FreeInit should lead to results similar to the default pipeline results",
        )

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_without_offload = pipe(**inputs).frames[0]
        output_without_offload = (
            output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
        )

        pipe.enable_xformers_memory_efficient_attention()
        inputs = self.get_dummy_inputs(torch_device)
        output_with_offload = pipe(**inputs).frames[0]
        output_with_offload = (
            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
        )

        max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
        self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
```