Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
[Pipeline] AnimateDiff SDXL (#6721)
* update conversion script to handle motion adapter sdxl checkpoint
* add animatediff xl
* handle addition_embed_type
* fix output
* update
* add imports
* make fix-copies
* add decode latents
* update docstrings
* add animatediff sdxl to docs
* remove unnecessary lines
* update example
* add test
* revert conv_in conv_out kernel param
* remove unused param addition_embed_type_num_heads
* latest IPAdapter impl
* make fix-copies
* fix return
* add IPAdapterTesterMixin to tests
* fix return
* revert based on suggestion
* add freeinit
* fix test_to_dtype test
* use StableDiffusionMixin instead of different helper methods
* fix progress bar iterations
* apply suggestions from review
* hardcode flip_sin_to_cos and freq_shift
* make fix-copies
* fix ip adapter implementation
* fix last failing test
* make style
* Update docs/source/en/api/pipelines/animatediff.md

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* remove todo
* fix doc-builder errors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you

</Tip>

### AnimateDiffSDXLPipeline

AnimateDiff can also be used with SDXL models. This is currently an experimental feature, as only a beta release of the motion adapter checkpoint is available.

```python
import torch
from diffusers.models import MotionAdapter
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
    steps_offset=1,
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
    model_id,
    motion_adapter=adapter,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

output = pipe(
    prompt="a panda surfing in the ocean, realistic, high quality",
    negative_prompt="low quality, worst quality",
    num_inference_steps=20,
    guidance_scale=8,
    width=1024,
    height=1024,
    num_frames=16,
)

frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```

### AnimateDiffVideoToVideoPipeline

Starting from an initial video, AnimateDiff can also generate visually similar videos or edit the style, characters, background, and other aspects of the clip, letting you explore creative variations seamlessly.
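The full video-to-video example lives further down in the docs and is outside this hunk. For orientation, here is a minimal sketch of how the new pipeline is typically driven; the model IDs, the `video` argument (a list of PIL frames), and the `strength` value are illustrative assumptions, not taken from this diff.

```python
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
from PIL import Image

# Hypothetical input clip: a list of PIL frames (file names are placeholders).
video = [Image.open(f"input_frames/{i:03d}.png").convert("RGB") for i in range(16)]

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
    steps_offset=1,
)

output = pipe(
    video=video,
    prompt="panda playing a guitar, on a boat, in the ocean, high quality",
    negative_prompt="bad quality, worse quality",
    strength=0.6,  # lower values stay closer to the input clip
    guidance_scale=7.5,
    num_inference_steps=25,
)
export_to_gif(output.frames[0], "edited.gif")
```

Lower `strength` keeps the result closer to the input frames, while higher values give the prompt more influence over the output.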
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
  - all
  - __call__

## AnimateDiffSDXLPipeline

[[autodoc]] AnimateDiffSDXLPipeline
  - all
  - __call__

## AnimateDiffVideoToVideoPipeline

[[autodoc]] AnimateDiffVideoToVideoPipeline
@@ -31,6 +31,7 @@ def get_args():
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--use_motion_mid_block", action="store_true")
    parser.add_argument("--motion_max_seq_length", type=int, default=32)
    parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
    parser.add_argument("--save_fp16", action="store_true")

    return parser.parse_args()

@@ -49,11 +50,13 @@ if __name__ == "__main__":

    conv_state_dict = convert_motion_module(state_dict)
    adapter = MotionAdapter(
        use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
        block_out_channels=args.block_out_channels,
        use_motion_mid_block=args.use_motion_mid_block,
        motion_max_seq_length=args.motion_max_seq_length,
    )
    # skip loading position embeddings
    adapter.load_state_dict(conv_state_dict, strict=False)
    adapter.save_pretrained(args.output_path)

    if args.save_fp16:
        adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
        adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
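The new `--block_out_channels` flag is what lets the same script size the adapter for SDXL, whose UNet has three resolution levels rather than the four used by SD1.5. A hypothetical construction mirroring what the script does for an SDXL checkpoint (the channel widths and the disabled mid block are assumptions based on the SDXL base UNet, not values read from this diff):

```python
# Sketch: constructing a MotionAdapter sized for an SDXL UNet, mirroring what the
# updated script does when --block_out_channels is overridden on the command line.
# The three-level (320, 640, 1280) widths and use_motion_mid_block=False are assumptions.
from diffusers.models import MotionAdapter

sdxl_adapter = MotionAdapter(
    block_out_channels=(320, 640, 1280),
    use_motion_mid_block=False,
    motion_max_seq_length=32,
)
```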
@@ -216,6 +216,7 @@ else:
            "AmusedInpaintPipeline",
            "AmusedPipeline",
            "AnimateDiffPipeline",
            "AnimateDiffSDXLPipeline",
            "AnimateDiffVideoToVideoPipeline",
            "AudioLDM2Pipeline",
            "AudioLDM2ProjectionModel",

@@ -595,6 +596,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AmusedInpaintPipeline,
        AmusedPipeline,
        AnimateDiffPipeline,
        AnimateDiffSDXLPipeline,
        AnimateDiffVideoToVideoPipeline,
        AudioLDM2Pipeline,
        AudioLDM2ProjectionModel,
@@ -121,6 +121,7 @@ def get_down_block(
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
        return CrossAttnDownBlockMotion(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,

@@ -255,6 +256,7 @@ def get_up_block(
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
        return CrossAttnUpBlockMotion(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
@@ -211,6 +211,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        use_linear_projection: bool = False,
        num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        motion_max_seq_length: int = 32,
@@ -218,6 +220,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        use_motion_mid_block: int = True,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        projection_class_embeddings_input_dim: Optional[int] = None,
        time_cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()
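The new `addition_embed_type`, `addition_time_embed_dim`, and `projection_class_embeddings_input_dim` arguments carry SDXL's extra conditioning through the motion UNet. As a rough check of how the sizes fit together (the concrete values below are assumptions taken from the standard SDXL base UNet config, not from this diff):

```python
# Dimension arithmetic behind the new config fields, using SDXL-style values
# (assumed from the SDXL base UNet config; they are not part of this diff).
addition_time_embed_dim = 256       # Fourier embedding size per micro-conditioning scalar
num_time_ids = 6                    # original size (2) + crop coords (2) + target size (2)
pooled_text_embed_dim = 1280        # pooled projection from the second text encoder

projection_class_embeddings_input_dim = addition_time_embed_dim * num_time_ids + pooled_text_embed_dim
assert projection_class_embeddings_input_dim == 2816

# The new test file below uses the same formula at toy scale: 6 * 8 + 32 == 80.
assert 6 * 8 + 32 == 80
```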
@@ -240,6 +245,21 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            for layer_number_per_block in transformer_layers_per_block:
                if isinstance(layer_number_per_block, list):
                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
@@ -260,6 +280,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        if encoder_hid_dim_type is None:
            self.encoder_hid_proj = None

        if addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
@@ -267,6 +291,15 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
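This broadcasting is what lets a single int config keep working while SDXL-style per-block tuples pass through unchanged; a small illustration (the SDXL-like tuple is an assumed example, not a value from this diff):

```python
# Illustration of the broadcasting above: plain ints become per-block tuples,
# while tuples that are already per-block are used as-is.
down_block_types = ("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D")

num_attention_heads = 8
num_attention_heads = (num_attention_heads,) * len(down_block_types)
assert num_attention_heads == (8, 8, 8)

transformer_layers_per_block = (1, 2, 10)  # assumed SDXL-like per-block tuple
assert len(transformer_layers_per_block) == len(down_block_types)
```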
@@ -276,7 +309,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                num_layers=layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
@@ -284,13 +317,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                dual_cross_attention=False,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
                transformer_layers_per_block=transformer_layers_per_block[i],
            )
            self.down_blocks.append(down_block)
@@ -302,13 +336,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=False,
                use_linear_projection=use_linear_projection,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
                transformer_layers_per_block=transformer_layers_per_block[-1],
            )

        else:
@@ -318,11 +353,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=False,
                use_linear_projection=use_linear_projection,
                transformer_layers_per_block=transformer_layers_per_block[-1],
            )

        # count how many layers upsample the images
@@ -331,6 +367,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                num_layers=reversed_layers_per_block[i] + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
                resolution_idx=i,
                use_linear_projection=use_linear_projection,
                temporal_num_attention_heads=motion_num_attention_heads,
                temporal_max_seq_length=motion_max_seq_length,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
@@ -835,6 +875,28 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.config.addition_embed_type == "text_time":
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )

            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)

        emb = emb if aug_emb is None else emb + aug_emb
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
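This is the same `text_time` scheme SDXL uses: the pooled text embedding and six micro-conditioning scalars are embedded, concatenated, and added to the time embedding before it is repeated per frame. A sketch of how a caller is expected to populate `added_cond_kwargs` (shapes and values are illustrative assumptions, not code from this diff):

```python
# Sketch (not part of this diff): filling `added_cond_kwargs` for an SDXL-style
# motion UNet call, based on the standard SDXL conditioning scheme.
import torch

batch_size = 2
pooled_text_embeds = torch.randn(batch_size, 1280)  # pooled projection from the second text encoder

# (original_h, original_w, crop_top, crop_left, target_h, target_w) per sample
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch_size, dtype=pooled_text_embeds.dtype)

added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}

# In the forward pass above, each of the six scalars is Fourier-embedded by
# `add_time_proj`, reshaped to (batch, 6 * addition_time_embed_dim), concatenated
# with `text_embeds`, and projected by `add_embedding` onto the time embedding.
```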
@@ -114,6 +114,7 @@ else:
    _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
    _import_structure["animatediff"] = [
        "AnimateDiffPipeline",
        "AnimateDiffSDXLPipeline",
        "AnimateDiffVideoToVideoPipeline",
    ]
    _import_structure["audioldm"] = ["AudioLDMPipeline"]

@@ -367,7 +368,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ..utils.dummy_torch_and_transformers_objects import *
    else:
        from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
        from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
        from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
        from .audioldm import AudioLDMPipeline
        from .audioldm2 import (
            AudioLDM2Pipeline,
@@ -22,6 +22,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
    _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
    _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

    else:
        from .pipeline_animatediff import AnimateDiffPipeline
        from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
        from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
        from .pipeline_output import AnimateDiffPipelineOutput
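Each pipeline has to be registered twice because diffusers builds its public namespace lazily: the `_import_structure` entry feeds a `_LazyModule` that resolves attributes on first access, while the `TYPE_CHECKING`/slow-import branch gives static tooling real imports. A simplified sketch of that pattern (not code added by this PR; the real files also handle optional-dependency fallbacks):

```python
# Simplified sketch of the lazy-import pattern these __init__ edits hook into.
import sys
from typing import TYPE_CHECKING

from diffusers.utils import _LazyModule

_import_structure = {"pipeline_animatediff_sdxl": ["AnimateDiffSDXLPipeline"]}

if TYPE_CHECKING:
    from diffusers.pipelines.animatediff.pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
else:
    # attribute access on this module triggers the real import on first use
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```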
src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py (new file, 1284 lines)
File diff suppressed because it is too large
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class AnimateDiffSDXLPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
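The dummy class keeps `from diffusers import AnimateDiffSDXLPipeline` importable when torch or transformers is missing and defers the error to first use; roughly (a sketch of the observable behavior, with a placeholder repo id):

```python
# Sketch of the observable behavior (not part of the diff): with torch/transformers
# missing, the import still works and the failure is deferred to first use.
from diffusers import AnimateDiffSDXLPipeline

try:
    # hypothetical repo id, used only to trigger the backend check
    AnimateDiffSDXLPipeline.from_pretrained("some-org/animatediff-sdxl-checkpoint")
except ImportError as err:
    print(err)  # requires_backends lists the backends that need installing
```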
tests/pipelines/animatediff/test_animatediff_sdxl.py (new file, 307 lines)
@@ -0,0 +1,307 @@
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

import diffusers
from diffusers import (
    AnimateDiffSDXLPipeline,
    AutoencoderKL,
    DDIMScheduler,
    MotionAdapter,
    UNet2DConditionModel,
    UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    IPAdapterTesterMixin,
    PipelineTesterMixin,
    SDFunctionTesterMixin,
    SDXLOptionalComponentsTesterMixin,
)


def to_np(tensor):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu().numpy()

    return tensor


class AnimateDiffPipelineSDXLFastTests(
    IPAdapterTesterMixin,
    SDFunctionTesterMixin,
    PipelineTesterMixin,
    SDXLOptionalComponentsTesterMixin,
    unittest.TestCase,
):
    pipeline_class = AnimateDiffSDXLPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64, 128),
            layers_per_block=2,
            time_cond_proj_dim=time_cond_proj_dim,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
            attention_head_dim=(2, 4, 8),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2, 4),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
            norm_num_groups=1,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="linear",
            clip_sample=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=32,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        motion_adapter = MotionAdapter(
            block_out_channels=(32, 64, 128),
            motion_layers_per_block=2,
            motion_norm_num_groups=2,
            motion_num_attention_heads=4,
            use_motion_mid_block=False,
        )

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "motion_adapter": motion_adapter,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_motion_unet_loading(self):
        components = self.get_dummy_components()
        pipe = AnimateDiffSDXLPipeline(**components)

        assert isinstance(pipe.unet, UNetMotionModel)

    @unittest.skip("Attention slicing is not enabled in this pipeline")
    def test_attention_slicing_forward_pass(self):
        pass

    def test_inference_batch_single_identical(
        self,
        batch_size=2,
        expected_max_diff=1e-4,
        additional_params_copy_to_batched_inputs=["num_inference_steps"],
    ):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for components in pipe.components.values():
            if hasattr(components, "set_default_attn_processor"):
                components.set_default_attn_processor()

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        # Reset generator in case it has been used in self.get_dummy_inputs
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # batchify inputs
        batched_inputs = {}
        batched_inputs.update(inputs)

        for name in self.batch_params:
            if name not in inputs:
                continue

            value = inputs[name]
            if name == "prompt":
                len_prompt = len(value)
                batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
                batched_inputs[name][-1] = 100 * "very long"

            else:
                batched_inputs[name] = batch_size * [value]

        if "generator" in inputs:
            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]

        if "batch_size" in inputs:
            batched_inputs["batch_size"] = batch_size

        for arg in additional_params_copy_to_batched_inputs:
            batched_inputs[arg] = inputs[arg]

        output = pipe(**inputs)
        output_batch = pipe(**batched_inputs)

        assert output_batch[0].shape[0] == batch_size

        max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
        assert max_diff < expected_max_diff

    @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
    def test_to_device(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        pipe.to("cpu")
        # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cpu" for device in model_devices))

        output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
        self.assertTrue(np.isnan(output_cpu).sum() == 0)

        pipe.to("cuda")
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        self.assertTrue(all(device == "cuda" for device in model_devices))

        output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
        self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(dtype=torch.float16)
        model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    def test_prompt_embeds(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        prompt = inputs.pop("prompt")

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(prompt)

        pipe(
            **inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )

    def test_save_load_optional_components(self):
        self._test_save_load_optional_components()

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_without_offload = pipe(**inputs).frames[0]
        output_without_offload = (
            output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
        )

        pipe.enable_xformers_memory_efficient_attention()
        inputs = self.get_dummy_inputs(torch_device)
        output_with_offload = pipe(**inputs).frames[0]
        output_with_offload = (
            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
        )

        max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
        self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")