[Pipeline] AnimateDiff SDXL (#6721)

* update conversion script to handle motion adapter sdxl checkpoint

* add animatediff xl

* handle addition_embed_type

* fix output

* update

* add imports

* make fix-copies

* add decode latents

* update docstrings

* add animatediff sdxl to docs

* remove unnecessary lines

* update example

* add test

* revert conv_in conv_out kernel param

* remove unused param addition_embed_type_num_heads

* latest IPAdapter impl

* make fix-copies

* fix return

* add IPAdapterTesterMixin to tests

* fix return

* revert based on suggestion

* add freeinit

* fix test_to_dtype test

* use StableDiffusionMixin instead of different helper methods

* fix progress bar iterations

* apply suggestions from review

* hardcode flip_sin_to_cos and freq_shift

* make fix-copies

* fix ip adapter implementation

* fix last failing test

* make style

* Update docs/source/en/api/pipelines/animatediff.md

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* remove todo

* fix doc-builder errors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Aryan
2024-05-08 21:27:14 +05:30
committed by GitHub
parent f29b93488d
commit 818f760732
10 changed files with 1740 additions and 9 deletions

View File

@@ -101,6 +101,53 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
</Tip>
### AnimateDiffSDXLPipeline
AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.
```python
import torch
from diffusers.models import MotionAdapter
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
clip_sample=False,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
output = pipe(
prompt="a panda surfing in the ocean, realistic, high quality",
negative_prompt="low quality, worst quality",
num_inference_steps=20,
guidance_scale=8,
width=1024,
height=1024,
num_frames=16,
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
### AnimateDiffVideoToVideoPipeline
AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
@@ -522,6 +569,12 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
- all
- __call__
## AnimateDiffSDXLPipeline
[[autodoc]] AnimateDiffSDXLPipeline
- all
- __call__
## AnimateDiffVideoToVideoPipeline
[[autodoc]] AnimateDiffVideoToVideoPipeline

View File

@@ -31,6 +31,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--use_motion_mid_block", action="store_true")
parser.add_argument("--motion_max_seq_length", type=int, default=32)
parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
parser.add_argument("--save_fp16", action="store_true")
return parser.parse_args()
@@ -49,11 +50,13 @@ if __name__ == "__main__":
conv_state_dict = convert_motion_module(state_dict)
adapter = MotionAdapter(
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
block_out_channels=args.block_out_channels,
use_motion_mid_block=args.use_motion_mid_block,
motion_max_seq_length=args.motion_max_seq_length,
)
# skip loading position embeddings
adapter.load_state_dict(conv_state_dict, strict=False)
adapter.save_pretrained(args.output_path)
if args.save_fp16:
adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")

View File

@@ -216,6 +216,7 @@ else:
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
@@ -595,6 +596,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffVideoToVideoPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,

View File

@@ -121,6 +121,7 @@ def get_down_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
return CrossAttnDownBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
return CrossAttnUpBlockMotion(
num_layers=num_layers,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,

View File

@@ -211,6 +211,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
use_linear_projection: bool = False,
num_attention_heads: Union[int, Tuple[int, ...]] = 8,
motion_max_seq_length: int = 32,
@@ -218,6 +220,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
use_motion_mid_block: int = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
time_cond_proj_dim: Optional[int] = None,
):
super().__init__()
@@ -240,6 +245,21 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_kernel = 3
conv_out_kernel = 3
@@ -260,6 +280,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -267,6 +291,15 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -276,7 +309,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
@@ -284,13 +317,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
dual_cross_attention=False,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[i],
)
self.down_blocks.append(down_block)
@@ -302,13 +336,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=transformer_layers_per_block[-1],
)
else:
@@ -318,11 +353,12 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=False,
use_linear_projection=use_linear_projection,
transformer_layers_per_block=transformer_layers_per_block[-1],
)
# count how many layers upsample the images
@@ -331,6 +367,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@@ -349,7 +388,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
@@ -358,13 +397,14 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=False,
resolution_idx=i,
use_linear_projection=use_linear_projection,
temporal_num_attention_heads=motion_num_attention_heads,
temporal_max_seq_length=motion_max_seq_length,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -835,6 +875,28 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

View File

@@ -114,6 +114,7 @@ else:
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
@@ -367,7 +368,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import (
AudioLDM2Pipeline,

View File

@@ -22,6 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
from .pipeline_output import AnimateDiffPipelineOutput

File diff suppressed because it is too large Load Diff

View File

@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffSDXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,307 @@
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
import diffusers
from diffusers import (
AnimateDiffSDXLPipeline,
AutoencoderKL,
DDIMScheduler,
MotionAdapter,
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
SDXLOptionalComponentsTesterMixin,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class AnimateDiffPipelineSDXLFastTests(
IPAdapterTesterMixin,
SDFunctionTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = AnimateDiffSDXLPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64, 128),
layers_per_block=2,
time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4, 8),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2, 4),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
norm_num_groups=1,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
motion_adapter = MotionAdapter(
block_out_channels=(32, 64, 128),
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
use_motion_mid_block=False,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_motion_unet_loading(self):
components = self.get_dummy_components()
pipe = AnimateDiffSDXLPipeline(**components)
assert isinstance(pipe.unet, UNetMotionModel)
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for components in pipe.components.values():
if hasattr(components, "set_default_attn_processor"):
components.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
# pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt)
pipe(
**inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = pipe(**inputs).frames[0]
output_without_offload = (
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
)
pipe.enable_xformers_memory_efficient_attention()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = pipe(**inputs).frames[0]
output_with_offload = (
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
)
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")