Compare commits

...

56 Commits

Author SHA1 Message Date
Dhruv Nair
88e76c6aee update 2023-11-02 13:21:17 +00:00
Dhruv Nair
beb1646b1c clean up 2023-11-02 13:06:31 +00:00
Dhruv Nair
0d6f5be46d update 2023-11-02 12:37:54 +00:00
Dhruv Nair
a2e38cce48 update 2023-11-02 12:33:38 +00:00
Dhruv Nair
ef893c4b38 fix docstrings 2023-11-02 10:35:04 +00:00
Dhruv Nair
c24c97b70b fix docstrings 2023-11-02 10:30:17 +00:00
Dhruv Nair
dfa52fbc81 clean up 2023-11-02 10:21:44 +00:00
Dhruv Nair
ee51b907cd clean up 2023-11-02 09:57:12 +00:00
Dhruv Nair
a6d025befc update 2023-11-02 09:53:19 +00:00
Dhruv Nair
840f576a7f clean up 2023-11-02 09:21:00 +00:00
Dhruv Nair
6d81f2aabe update 2023-11-02 09:04:20 +00:00
Dhruv Nair
d41f71783b add docs 2023-11-01 16:56:37 +00:00
Dhruv Nair
ec8bb6e119 fix mistake 2023-11-01 16:09:22 +00:00
Dhruv Nair
6f6f8aa258 fix bug 2023-11-01 16:06:46 +00:00
Dhruv Nair
5f003e5f15 max fix copies 2023-11-01 15:38:02 +00:00
Dhruv Nair
dc6eb04b4e merge upstream 2023-11-01 15:30:34 +00:00
Dhruv Nair
5e43f2412e Merge branch 'main' into animatediff-model 2023-11-01 15:16:01 +00:00
Dhruv Nair
9e6a146ad1 update 2023-11-01 14:47:25 +00:00
DN6
d939379906 fix embeddings 2023-11-01 16:35:08 +05:30
Dhruv Nair
3f5d8dec4b update 2023-11-01 10:01:55 +00:00
Dhruv Nair
2b78f1edb6 make style 2023-11-01 08:33:09 +00:00
Dhruv Nair
5d65837a46 update 2023-10-31 17:24:44 +00:00
Dhruv Nair
71dc350996 update 2023-10-31 09:33:27 +00:00
Dhruv Nair
37de1de70f update 2023-10-31 06:36:57 +00:00
Dhruv Nair
e82331e0c3 update 2023-10-30 16:14:29 +00:00
Dhruv Nair
313db1dd32 update model test 2023-10-27 06:08:47 +00:00
Dhruv Nair
8be5f1f892 update 2023-10-26 19:14:17 +00:00
Dhruv Nair
4d0b5ecd43 update 2023-10-26 11:42:06 +00:00
Dhruv Nair
3ba1ba0e18 clean up 2023-10-26 10:22:57 +00:00
Dhruv Nair
bf5b65a024 update 2023-10-26 08:49:17 +00:00
DN6
4df582eef5 update 2023-10-26 13:53:28 +05:30
Dhruv Nair
bcbc2d1507 update 2023-10-26 07:58:46 +00:00
Dhruv Nair
fe3828a3e7 update 2023-10-25 18:35:28 +00:00
Dhruv Nair
ee79cf37e9 update 2023-10-25 15:10:53 +00:00
Dhruv Nair
c7e1b14e4c update 2023-10-25 15:05:13 +00:00
Dhruv Nair
0e1f7a83f3 update 2023-10-25 15:00:42 +00:00
Dhruv Nair
22c9f7b3e3 update 2023-10-25 14:19:55 +00:00
Dhruv Nair
1bd65de5d8 clean up 2023-10-25 11:28:55 +00:00
Dhruv Nair
9c66c21bfd clean up 2023-10-25 11:04:52 +00:00
Dhruv Nair
0deab59ca8 clean up 2023-10-25 10:47:25 +00:00
Dhruv Nair
2688d07a60 change motion block 2023-10-25 10:26:40 +00:00
Dhruv Nair
b24f58a5f4 add tests 2023-10-25 10:04:37 +00:00
Dhruv Nair
79f402f2d6 clean up 2023-10-24 15:54:12 +00:00
Dhruv Nair
6ec184ab96 clean up 2023-10-24 06:40:33 +00:00
Dhruv Nair
c7ba4b8c13 clean up 2023-10-23 14:09:17 +00:00
Dhruv Nair
86a4d31cdb update pipeline 2023-10-22 17:56:29 +00:00
Dhruv Nair
9eeee36d6e clean up 2023-10-22 17:13:24 +00:00
Dhruv Nair
7a5fbf8e9e clean up 2023-10-22 16:49:38 +00:00
Dhruv Nair
d8d3515ed2 clean up 2023-10-21 22:13:13 +00:00
Dhruv Nair
72e0fa65d9 clean up 2023-10-20 17:12:03 +00:00
Dhruv Nair
2db7bd3516 clean up 2023-10-20 07:31:06 +00:00
Dhruv Nair
36b3a44a4c clean up 2023-10-18 18:52:58 +00:00
DN6
bbb2b6cb96 clean up 2023-10-16 22:20:45 +05:30
DN6
a026ea5024 clean up 2023-10-16 20:25:52 +05:30
Dhruv Nair
9e4c700441 clean up 2023-10-16 12:14:59 +00:00
DN6
d8ced0ff7a draft design 2023-10-15 21:13:16 +05:30
18 changed files with 3322 additions and 1 deletion

View File

@@ -184,6 +184,8 @@
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/unet-motion
title: UNetMotionModel
- local: api/models/vq
title: VQModel
- local: api/models/autoencoderkl
@@ -206,6 +208,8 @@
title: Overview
- local: api/pipelines/alt_diffusion
title: AltDiffusion
- local: api/pipelines/animatediff
title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/audio_diffusion
@@ -392,5 +396,5 @@
title: Utilities
- local: api/image_processor
title: VAE Image Processor
title: Internal classes
title: Internal classes
title: API

View File

@@ -0,0 +1,13 @@
# UNetMotionModel
The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. The `UNetMotionModel` is a 2D UNet model extended with temporal motion modules (from AnimateDiff) so that it can produce a sequence of frames instead of a single image.
The abstract from the paper is:
*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
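A `UNetMotionModel` is usually not trained from scratch; it is assembled from an existing `UNet2DConditionModel` and a `MotionAdapter` via `UNetMotionModel.from_unet2d`. A minimal sketch (the model IDs are assumptions; any Stable Diffusion 1.5 checkpoint with a compatible motion adapter should work):

```python
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Load the spatial (image) UNet from a Stable Diffusion 1.5 checkpoint
unet2d = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Load pretrained AnimateDiff motion modules
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Copy the 2D UNet weights and inject the motion modules
unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)
```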
## UNetMotionModel
[[autodoc]] UNetMotionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput

View File

@@ -0,0 +1,108 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Text-to-Video Generation with AnimateDiff
## Overview
[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai
The abstract of the paper is the following:
*With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at this https URL.*
## Available Pipelines:
| Pipeline | Tasks | Demo |
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* | |
## Usage example
AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the ResNet and attention blocks in the Stable Diffusion UNet.
The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference with a Stable Diffusion 1.4/1.5 based model.
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
# load SD 1.5 based finetuned model
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
scheduler = DDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
)
pipe.scheduler = scheduler
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()
output = pipe(
prompt=(
"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
"orange sky, warm lighting, fishing boats, ocean waves seagulls, "
"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
"golden hour, coastal landscape, seaside scenery"
),
negative_prompt="bad quality, worse quality",
num_frames=16,
guidance_scale=7.5,
num_inference_steps=25,
generator=torch.Generator("cpu").manual_seed(42),
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
Here are some sample outputs:
<table>
<tr>
<td><center>
masterpiece, bestquality, sunset.
<br>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-realistic-doc.gif"
alt="masterpiece, bestquality, sunset"
style="width: 300px;" />
</center></td>
</tr>
</table>
<Tip>
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler, as sample clipping can have an adverse effect on the generated frames.
</Tip>
## AnimateDiffPipeline
[[autodoc]] AnimateDiffPipeline
- all
- __call__
- enable_freeu
- disable_freeu
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
- disable_vae_tiling
## AnimateDiffPipelineOutput
[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
## Available checkpoints
Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5.
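For example, the adapter can be paired with a different Stable Diffusion 1.5 based checkpoint simply by changing the base model id (the checkpoint below is only an illustration):

```python
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Any Stable Diffusion 1.4/1.5 based checkpoint can serve as the base model
pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter)
```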

View File

@@ -79,6 +79,7 @@ else:
"AutoencoderTiny",
"ControlNetModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"T2IAdapter",
@@ -88,6 +89,7 @@ else:
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"VQModel",
]
)
@@ -195,6 +197,7 @@ else:
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -440,6 +443,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
ControlNetModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
@@ -449,6 +453,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
VQModel,
)
from .optimization import (
@@ -537,6 +542,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,

View File

@@ -35,6 +35,7 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -60,6 +61,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel
if is_flax_available():

View File

@@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
@@ -96,6 +97,10 @@ class BasicTransformerBlock(nn.Module):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to the sequence input.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
@@ -115,6 +120,8 @@ class BasicTransformerBlock(nn.Module):
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -128,6 +135,16 @@ class BasicTransformerBlock(nn.Module):
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
@@ -207,6 +224,9 @@ class BasicTransformerBlock(nn.Module):
else:
norm_hidden_states = self.norm1(hidden_states)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -234,6 +254,8 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,

View File

@@ -251,6 +251,33 @@ class GaussianFourierProjection(nn.Module):
return out
class SinusoidalPositionalEmbedding(nn.Module):
"""Apply positional information to a sequence of embeddings.
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
them.
Args:
embed_dim (`int`): Dimension of the positional embedding.
max_seq_length (`int`, *optional*, defaults to 32): Maximum sequence length over which to apply positional embeddings.
"""
def __init__(self, embed_dim: int, max_seq_length: int = 32):
super().__init__()
position = torch.arange(max_seq_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(1, max_seq_length, embed_dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
_, seq_length, _ = x.shape
x = x + self.pe[:, :seq_length]
return x
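A short usage sketch of the new module (shapes chosen only for illustration; the module adds fixed sinusoidal positions along the sequence axis and leaves the shape unchanged):

```python
import torch

from diffusers.models.embeddings import SinusoidalPositionalEmbedding

pos_embed = SinusoidalPositionalEmbedding(embed_dim=320, max_seq_length=32)

# (batch_size, seq_length, embed_dim); for the motion modules, seq_length is the number of frames
hidden_states = torch.randn(2, 16, 320)
hidden_states = pos_embed(hidden_states)  # positions added, shape stays (2, 16, 320)
```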
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the

View File

@@ -59,6 +59,10 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings (`str`, *optional*):
The type of positional embeddings to apply to the sequence input.
num_positional_embeddings (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
@@ -77,6 +81,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -101,6 +107,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
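For reference, the motion modules later in this diff construct the temporal transformer with the new arguments roughly like this (a sketch with assumed channel sizes, mirroring `DownBlockMotion`):

```python
from diffusers.models.transformer_temporal import TransformerTemporalModel

motion_module = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=320 // 8,
    in_channels=320,
    norm_num_groups=32,
    cross_attention_dim=None,
    attention_bias=False,
    activation_fn="geglu",
    positional_embeddings="sinusoidal",  # new: adds SinusoidalPositionalEmbedding in each block
    num_positional_embeddings=32,        # new: maximum number of frames the embedding covers
)
```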

View File

@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -39,6 +43,8 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
):
if down_block_type == "DownBlock3D":
return DownBlock3D(
@@ -74,6 +80,45 @@ def get_down_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
if down_block_type == "DownBlockMotion":
return DownBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
elif down_block_type == "CrossAttnDownBlockMotion":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
raise ValueError(f"{down_block_type} does not exist.")
@@ -96,6 +141,9 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
temporal_num_attention_heads=8,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
):
if up_block_type == "UpBlock3D":
return UpBlock3D(
@@ -133,6 +181,46 @@ def get_up_block(
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
if up_block_type == "UpBlockMotion":
return UpBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
elif up_block_type == "CrossAttnUpBlockMotion":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
temporal_num_attention_heads=temporal_num_attention_heads,
temporal_max_seq_length=temporal_max_seq_length,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -724,3 +812,800 @@ class UpBlock3D(nn.Module):
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class DownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
output_states = ()
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnDownBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
temb=None,
encoder_hidden_states=None,
attention_mask=None,
num_frames=1,
encoder_attention_mask=None,
cross_attention_kwargs=None,
additional_residuals=None,
):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class CrossAttnUpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.attentions, self.motion_modules)
for resnet, attn, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
class UpBlockMotion(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
temporal_norm_num_groups=32,
temporal_cross_attention_dim=None,
temporal_num_attention_heads=8,
temporal_max_seq_length=32,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
in_channels=out_channels,
norm_num_groups=temporal_norm_num_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
activation_fn="geglu",
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
attention_head_dim=out_channels // temporal_num_attention_heads,
)
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
num_frames,
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
class UNetMidBlockCrossAttnMotion(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
temporal_num_attention_heads=1,
temporal_cross_attention_dim=None,
temporal_max_seq_length=32,
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
motion_modules = []
for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
motion_modules.append(
TransformerTemporalModel(
num_attention_heads=temporal_num_attention_heads,
attention_head_dim=in_channels // temporal_num_attention_heads,
in_channels=in_channels,
norm_num_groups=resnet_groups,
cross_attention_dim=temporal_cross_attention_dim,
attention_bias=False,
positional_embeddings="sinusoidal",
num_positional_embeddings=temporal_max_seq_length,
activation_fn="geglu",
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
num_frames=1,
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states

View File

@@ -0,0 +1,874 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (
CrossAttnDownBlockMotion,
CrossAttnUpBlockMotion,
DownBlockMotion,
UNetMidBlockCrossAttnMotion,
UpBlockMotion,
get_down_block,
get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MotionModules(nn.Module):
def __init__(
self,
in_channels,
layers_per_block=2,
num_attention_heads=8,
attention_bias=False,
cross_attention_dim=None,
activation_fn="geglu",
norm_num_groups=32,
max_seq_length=32,
):
super().__init__()
self.motion_modules = nn.ModuleList([])
for i in range(layers_per_block):
self.motion_modules.append(
TransformerTemporalModel(
in_channels=in_channels,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads,
positional_embeddings="sinusoidal",
num_positional_embeddings=max_seq_length,
)
)
class MotionAdapter(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
block_out_channels=(320, 640, 1280, 1280),
motion_layers_per_block=2,
motion_mid_block_layers_per_block=1,
motion_num_attention_heads=8,
motion_norm_num_groups=32,
motion_max_seq_length=32,
use_motion_mid_block=True,
):
"""Container to store AnimateDiff Motion Modules
Args:
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each UNet block.
motion_layers_per_block (`int`, *optional*, defaults to 2):
The number of motion layers per UNet block.
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
The number of motion layers in the middle UNet block.
motion_num_attention_heads (`int`, *optional*, defaults to 8):
The number of heads to use in each attention layer of the motion module.
motion_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use in each group normalization layer of the motion module.
motion_max_seq_length (`int`, *optional*, defaults to 32):
The maximum sequence length to use in the motion module.
use_motion_mid_block (`bool`, *optional*, defaults to True):
Whether to use a motion module in the middle of the UNet.
"""
super().__init__()
down_blocks = []
up_blocks = []
for i, channel in enumerate(block_out_channels):
output_channel = block_out_channels[i]
down_blocks.append(
MotionModules(
in_channels=output_channel,
norm_num_groups=motion_norm_num_groups,
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block,
)
)
if use_motion_mid_block:
self.mid_block = MotionModules(
in_channels=block_out_channels[-1],
norm_num_groups=motion_norm_num_groups,
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
layers_per_block=motion_mid_block_layers_per_block,
max_seq_length=motion_max_seq_length,
)
else:
self.mid_block = None
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, channel in enumerate(reversed_block_out_channels):
output_channel = reversed_block_out_channels[i]
up_blocks.append(
MotionModules(
in_channels=output_channel,
norm_num_groups=motion_norm_num_groups,
cross_attention_dim=None,
activation_fn="geglu",
attention_bias=False,
num_attention_heads=motion_num_attention_heads,
max_seq_length=motion_max_seq_length,
layers_per_block=motion_layers_per_block + 1,
)
)
self.down_blocks = nn.ModuleList(down_blocks)
self.up_blocks = nn.ModuleList(up_blocks)
def forward(self, sample):
pass
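As an illustration, the adapter documented above can be instantiated with its defaults or loaded from a pretrained checkpoint (a sketch; the Hub id matches the one used in the AnimateDiff docs):

```python
from diffusers import MotionAdapter

# Randomly initialized motion modules with the default AnimateDiff layout
adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=2,
    motion_num_attention_heads=8,
    motion_max_seq_length=32,
    use_motion_mid_block=True,
)

# Or load trained motion modules from the Hub
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
```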
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample-shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
up_block_types: Tuple[str] = (
"UpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
"CrossAttnUpBlockMotion",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
use_linear_projection: bool = False,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
motion_max_seq_length: Optional[int] = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: bool = True,
):
super().__init__()
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# input
conv_in_kernel = 3
conv_out_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], True, 0)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
)
self.down_blocks.append(down_block)
# mid
if use_motion_mid_block:
self.mid_block = UNetMidBlockCrossAttnMotion(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
)
else:
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
@classmethod
def from_unet2d(
cls,
unet: UNet2DConditionModel,
motion_adapter: Optional[MotionAdapter] = None,
load_weights: bool = True,
):
has_motion_adapter = motion_adapter is not None
# based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
config = unet.config
config["_class_name"] = cls.__name__
down_blocks = []
for down_blocks_type in config["down_block_types"]:
if "CrossAttn" in down_blocks_type:
down_blocks.append("CrossAttnDownBlockMotion")
else:
down_blocks.append("DownBlockMotion")
config["down_block_types"] = down_blocks
up_blocks = []
for up_blocks_type in config["up_block_types"]:
if "CrossAttn" in up_blocks_type:
up_blocks.append("CrossAttnUpBlockMotion")
else:
up_blocks.append("UpBlockMotion")
config["up_block_types"] = up_blocks
if has_motion_adapter:
config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
# Need this for backwards compatibility with UNet2DConditionModel checkpoints
if not config.get("num_attention_heads"):
config["num_attention_heads"] = config["attention_head_dim"]
model = cls.from_config(config)
if not load_weights:
return model
model.conv_in.load_state_dict(unet.conv_in.state_dict())
model.time_proj.load_state_dict(unet.time_proj.state_dict())
model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
for i, down_block in enumerate(unet.down_blocks):
model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
if hasattr(model.down_blocks[i], "attentions"):
model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
if model.down_blocks[i].downsamplers:
model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())
for i, up_block in enumerate(unet.up_blocks):
model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
if hasattr(model.up_blocks[i], "attentions"):
model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
if model.up_blocks[i].upsamplers:
model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())
model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
if unet.conv_norm_out is not None:
model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
if unet.conv_act is not None:
model.conv_act.load_state_dict(unet.conv_act.state_dict())
model.conv_out.load_state_dict(unet.conv_out.state_dict())
if has_motion_adapter:
model.load_motion_modules(motion_adapter)
# ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
model.to(unet.dtype)
return model
def freeze_unet2d_params(self):
"""Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
unfrozen for fine-tuning.
"""
# Freeze everything
for param in self.parameters():
param.requires_grad = False
# Unfreeze Motion Modules
for down_block in self.down_blocks:
motion_modules = down_block.motion_modules
for param in motion_modules.parameters():
param.requires_grad = True
for up_block in self.up_blocks:
motion_modules = up_block.motion_modules
for param in motion_modules.parameters():
param.requires_grad = True
if hasattr(self.mid_block, "motion_modules"):
motion_modules = self.mid_block.motion_modules
for param in motion_modules.parameters():
param.requires_grad = True
return
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]):
for i, down_block in enumerate(motion_adapter.down_blocks):
self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
for i, up_block in enumerate(motion_adapter.up_blocks):
self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
# to support older motion modules that don't have a mid_block
if hasattr(self.mid_block, "motion_modules"):
self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
def save_motion_modules(
self,
save_directory: str,
is_main_process: bool = True,
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
):
state_dict = self.state_dict()
# Extract all motion modules
motion_state_dict = {}
for k, v in state_dict.items():
if "motion_modules" in k:
motion_state_dict[k] = v
adapter = MotionAdapter(
block_out_channels=self.config["block_out_channels"],
motion_layers_per_block=self.config["layers_per_block"],
motion_norm_num_groups=self.config["norm_num_groups"],
motion_num_attention_heads=self.config["motion_num_attention_heads"],
motion_max_seq_length=self.config["motion_max_seq_length"],
use_motion_mid_block=self.config["use_motion_mid_block"],
)
adapter.load_state_dict(motion_state_dict)
adapter.save_pretrained(
save_directory=save_directory,
is_main_process=is_main_process,
safe_serialization=safe_serialization,
variant=variant,
push_to_hub=push_to_hub,
**kwargs,
)
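# Illustrative round-trip sketch for `save_motion_modules`: the saved directory can be
# reloaded as a `MotionAdapter` and injected into another motion UNet. `motion_unet` and
# the directory name are assumptions for the example.
motion_unet.save_motion_modules("animatediff-motion-modules")
adapter = MotionAdapter.from_pretrained("animatediff-motion-modules")
motion_unet.load_motion_modules(adapter)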
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size=None, dim=0):
"""
        Enables [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) for the
        feed-forward layers.
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
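# Illustrative sketch for feed-forward chunking: chunk the transformer feed-forward
# layers along the batch dimension to lower peak memory at the cost of extra compute.
# `motion_unet` is assumed to be a UNetMotionModel instance (e.g. from the sketch above).
motion_unet.enable_forward_chunking(chunk_size=1, dim=0)
# ... run inference here ...
motion_unet.disable_forward_chunking()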
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
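# Illustrative FreeU sketch: the scaling factors are model-dependent; the values below
# are ones commonly suggested for Stable Diffusion v1.x-style UNets and should be treated
# as a starting point, not a guarantee. `motion_unet` is assumed to be a UNetMotionModel.
motion_unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# ... denoise ...
motion_unet.disable_freeu()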
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
The [`UNetMotionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
through the `self.time_embedding` layer to obtain the timestep embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
num_frames = sample.shape[2]
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
# 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
# To support older versions of motion modules that don't have a mid_block
if hasattr(self.mid_block, "motion_modules"):
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
if mid_block_additional_residual is not None:
sample = sample + mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
num_frames=num_frames,
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
        # reshape to (batch, channel, num_frames, height, width)
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
if not return_dict:
return (sample,)
return UNet3DConditionOutput(sample=sample)
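# Illustrative shape sketch for the forward pass: the motion UNet consumes and returns
# 5D tensors laid out as (batch, channel, num_frames, height, width). The tiny config
# below mirrors the unit tests further down and is only meant to keep the sketch fast.
import torch
from diffusers import UNetMotionModel

model = UNetMotionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlockMotion", "DownBlockMotion"),
    up_block_types=("UpBlockMotion", "CrossAttnUpBlockMotion"),
    cross_attention_dim=32,
    num_attention_heads=4,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=32,
)
sample = torch.randn(1, 4, 8, 32, 32)          # (batch, channel, num_frames, height, width)
encoder_hidden_states = torch.randn(1, 4, 32)  # (batch, sequence_length, cross_attention_dim)
with torch.no_grad():
    out = model(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == sample.shape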


@@ -62,6 +62,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -291,6 +292,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline


@@ -0,0 +1,46 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)


@@ -0,0 +1,694 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_gif
>>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter")
>>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
>>> output = pipe(prompt="A corgi walking in the park")
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")
```
"""
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
return outputs
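# Illustrative shape sketch for the `tensor2vid` helper defined above (assumes a default
# VaeImageProcessor): a (batch, channel, num_frames, height, width) tensor comes back as
# a list with one entry per batch element, each holding `num_frames` post-processed frames.
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
video = torch.rand(2, 3, 16, 64, 64)  # batch=2, RGB, 16 frames of 64x64
frames = tensor2vid(video, processor, output_type="np")
assert len(frames) == 2 and frames[0].shape == (16, 64, 64, 3)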
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
A [`~transformers.CLIPTokenizer`] to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
motion_adapter ([`MotionAdapter`]):
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
):
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
            # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
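# Illustrative sketch: `encode_prompt` can be called on its own to pre-compute embeddings
# (e.g. for prompt weighting) and feed them back through `prompt_embeds` /
# `negative_prompt_embeds`. `pipe` is assumed to be an already-loaded AnimateDiffPipeline.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    "A corgi walking in the park",
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="low quality, blurry",
)
output = pipe(prompt=None, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds)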
    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = (
image[None, :]
.reshape(
(
batch_size,
num_frames,
-1,
)
+ image.shape[2:]
)
.permute(0, 2, 1, 3, 4)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
        processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
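# Illustrative shape sketch for `prepare_latents`: with the standard SD VAE the spatial
# dimensions are divided by `vae_scale_factor` (2 ** (len(block_out_channels) - 1) == 8),
# so the numbers below (all assumptions) give a (1, 4, 16, 64, 64) latent tensor.
batch_size, num_channels_latents, num_frames = 1, 4, 16
height = width = 512
vae_scale_factor = 8
latent_shape = (batch_size, num_channels_latents, num_frames, height // vae_scale_factor, width // vae_scale_factor)
assert latent_shape == (1, 4, 16, 64, 64)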
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_frames (`int`, *optional*, defaults to 16):
                The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
                amounts to 2 seconds of video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`AnimateDiffPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
Examples:
Returns:
            [`AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, an [`AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
# Post-processing
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)


@@ -77,6 +77,21 @@ class ModelMixin(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MotionAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class MultiAdapter(metaclass=DummyObject):
_backends = ["torch"]
@@ -212,6 +227,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UNetMotionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]


@@ -32,6 +32,21 @@ class AltDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AudioLDM2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]


@@ -0,0 +1,306 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import tempfile
import unittest
import numpy as np
import torch
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetMotionModel
main_input_name = "sample"
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 8
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 8, 32, 32)
@property
def output_shape(self):
return (4, 8, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
"up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
"cross_attention_dim": 32,
"num_attention_heads": 4,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_unet2d(self):
torch.manual_seed(0)
unet2d = UNet2DConditionModel()
torch.manual_seed(1)
model = self.model_class.from_unet2d(unet2d)
model_state_dict = model.state_dict()
for param_name, param_value in unet2d.named_parameters():
self.assertTrue(torch.equal(model_state_dict[param_name], param_value))
def test_freeze_unet2d(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.freeze_unet2d_params()
for param_name, param_value in model.named_parameters():
if "motion_modules" not in param_name:
self.assertFalse(param_value.requires_grad)
else:
self.assertTrue(param_value.requires_grad)
def test_loading_motion_adapter(self):
model = self.model_class()
adapter = MotionAdapter()
model.load_motion_modules(adapter)
for idx, down_block in enumerate(model.down_blocks):
adapter_state_dict = adapter.down_blocks[idx].motion_modules.state_dict()
for param_name, param_value in down_block.motion_modules.named_parameters():
self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value))
for idx, up_block in enumerate(model.up_blocks):
adapter_state_dict = adapter.up_blocks[idx].motion_modules.state_dict()
for param_name, param_value in up_block.motion_modules.named_parameters():
self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value))
mid_block_adapter_state_dict = adapter.mid_block.motion_modules.state_dict()
for param_name, param_value in model.mid_block.motion_modules.named_parameters():
self.assertTrue(torch.equal(mid_block_adapter_state_dict[param_name], param_value))
def test_saving_motion_modules(self):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_motion_modules(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")))
adapter_loaded = MotionAdapter.from_pretrained(tmpdirname)
torch.manual_seed(0)
model_loaded = self.model_class(**init_dict)
model_loaded.load_motion_modules(adapter_loaded)
model_loaded.to(torch_device)
with torch.no_grad():
output = model(**inputs_dict)[0]
output_loaded = model_loaded(**inputs_dict)[0]
max_diff = (output - output_loaded).abs().max().item()
self.assertLessEqual(max_diff, 1e-4, "Models give different forward passes")
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model_class_copy = copy.copy(self.model_class)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"CrossAttnUpBlockMotion",
"CrossAttnDownBlockMotion",
"UNetMidBlockCrossAttnMotion",
"UpBlockMotion",
"Transformer2DModel",
"DownBlockMotion",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
def test_pickle(self):
        # run a forward pass and check that the output survives a shallow copy
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_from_save_pretrained(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)
torch.manual_seed(0)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
with torch.no_grad():
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
torch.manual_seed(0)
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with self.assertRaises(OSError) as error_context:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
new_model.to(torch_device)
with torch.no_grad():
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")


@@ -0,0 +1,279 @@
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
AnimateDiffPipeline,
AutoencoderKL,
DDIMScheduler,
MotionAdapter,
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.utils import logging
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = AnimateDiffPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback",
"callback_steps",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
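# tiny UNet configuration keeps the fast tests cheap enough to run on CPU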
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
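# motion adapter sized to match the tiny UNet (same block_out_channels)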
motion_adapter = MotionAdapter(
block_out_channels=(32, 64),
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
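# device-bound generators are not reliably supported on MPS, so fall back to the global manual seed there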
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"output_type": "pt",
}
return inputs
def test_motion_unet_loading(self):
components = self.get_dummy_components()
pipe = AnimateDiffPipeline(**components)
assert isinstance(pipe.unet, UNetMotionModel)
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset the generator in case it has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
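# prompts are batched with varying lengths (and one very long prompt) to exercise padding and truncation;
# all other batchable params are simply replicated batch_size times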
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
# the pipeline creates a new motion UNet under the hood, so check the device via pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# the pipeline creates a new motion UNet under the hood, so check the dtype via pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(torch_dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
@slow
@require_torch_gpu
class AnimateDiffPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_animatediff(self):
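# run a short end-to-end generation with a community checkpoint and the released motion adapter weights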
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
pipe = pipe.to(torch_device)
pipe.scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
steps_offset=1,
clip_sample=False,
)
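# memory optimizations: decode the latents in slices and offload idle submodules to the CPU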
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
negative_prompt = "bad quality, worse quality"
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(
prompt,
negative_prompt=negative_prompt,
num_frames=16,
generator=generator,
guidance_scale=7.5,
num_inference_steps=3,
output_type="np",
)
image = output.frames[0]
assert image.shape == (16, 512, 512, 3)
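# check a 3x3 corner patch of the first frame (last channel) against reference values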
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array(
[
0.11357737,
0.11285847,
0.11180121,
0.11084166,
0.11414117,
0.09785956,
0.10742754,
0.10510018,
0.08045256,
]
)
assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3