Compare commits

...

10 Commits

Author  SHA1        Message  Date
Aryan   6830fb0805  remove print statements  2024-08-21 11:52:05 +02:00
Aryan   761c44d116  refactor chunked inference changes  2024-08-21 11:47:31 +02:00
Aryan   76f931d7c8  Merge branch 'main' into animatediff/freenoise-improvements  2024-08-19 05:45:29 +02:00
Aryan   65686818ab  update animatediff controlnet with latest changes  2024-08-18 23:54:55 +02:00
Aryan   ec91064966  update  2024-08-18 18:42:15 +02:00
Aryan   74e3ab088c  more memory optimizations; todo: refactor  2024-08-18 06:14:44 +02:00
Aryan   94438e1439  resnet memory optimizations  2024-08-18 02:05:32 +02:00
Aryan   a86eabe0bd  make style  2024-08-15 17:20:32 +02:00
Aryan   d55903d0b2  implement prompt interpolation  2024-08-15 17:20:05 +02:00
Aryan   d0a81ae604  update  2024-08-14 16:21:29 +02:00
7 changed files with 561 additions and 185 deletions

View File

@@ -43,6 +43,12 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
return ff_output
def _experimental_split_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int
) -> torch.Tensor:
return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim)
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
@@ -525,7 +531,10 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
# ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
ff_output = _experimental_split_feed_forward(
self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim
)
else:
ff_output = self.ff(norm_hidden_states)
@@ -972,15 +981,32 @@ class FreeNoiseTransformerBlock(nn.Module):
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "pyramid":
if weighting_scheme == "flat":
weights = [1.0] * num_frames
elif weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
mid = num_frames // 2
weights = list(range(1, mid + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
mid = (num_frames + 1) // 2
weights = list(range(1, mid))
weights = weights + [mid] + weights[::-1]
elif weighting_scheme == "delayed_reverse_sawtooth":
if num_frames % 2 == 0:
# num_frames = 4 => [0.01, 2, 2, 1]
mid = num_frames // 2
weights = [0.01] * (mid - 1) + [mid]
weights = weights + list(range(mid, 0, -1))
else:
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = [0.01] * (mid - 1)
weights = weights + list(range(mid, 0, -1))
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
@@ -1087,16 +1113,38 @@ class FreeNoiseTransformerBlock(nn.Module):
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
).to(dtype)
# hidden_states = torch.where(
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# ).to(dtype)
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
# norm_hidden_states = torch.cat([
# self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim)
# ], dim=self._chunk_dim)
# ff_output = torch.cat([
# self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim)
# ], dim=self._chunk_dim)
# ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
ff_output = _experimental_split_feed_forward(
self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim
)
else:
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
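Replacing the single full-tensor torch.where with a concatenation over context_length-sized splits is purely a memory optimization; the output is unchanged. A small sketch of the equivalence (shapes are illustrative, not the pipeline's actual ones):

import torch

context_length = 16
accumulated = torch.randn(1, 64, 8, 32)                  # (batch, frames, tokens, channels)
counts = torch.randint(0, 3, (1, 64, 1, 1)).float()      # how many windows touched each frame

# Original approach: one torch.where over the full frame axis.
full = torch.where(counts > 0, accumulated / counts, accumulated)

# New approach: identical math, done one context window at a time to keep intermediates small.
split = torch.cat(
    [
        torch.where(c > 0, a / c, a)
        for a, c in zip(accumulated.split(context_length, dim=1), counts.split(context_length, dim=1))
    ],
    dim=1,
)
assert torch.allclose(full, split)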

View File

@@ -2221,6 +2221,8 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
# TODO: figure out a better way to do this
# hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0)
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

View File

@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states
@@ -2178,7 +2157,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:

View File

@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt: Optional[Union[str, List[str]]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
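With these changes, `prompt` may be a dict mapping frame indices to prompts when FreeNoise is enabled, and `_encode_prompt_free_noise` interpolates the embeddings between the given frames. A hedged usage sketch (checkpoint names, frame counts and step counts are illustrative, not prescribed by this PR):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Any AnimateDiff-compatible base model and motion adapter should work; these are just examples.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

pipe.enable_free_noise(context_length=16, context_stride=4)

# Frame-indexed prompts: a key 0 is required, and the largest key must be < num_frames.
prompt = {
    0: "a caterpillar on a leaf, macro photography",
    80: "a butterfly resting on a flower, macro photography",
}
output = pipe(prompt=prompt, negative_prompt="blurry, low quality", num_frames=128, num_inference_steps=25)
frames = output.frames[0]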

View File

@@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
@@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline(
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline(
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
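The new `interrupt` property (added to all three pipelines in this diff) is checked at the top of every denoising step, so setting the underlying flag from a step-end callback stops generation early. A hedged sketch, assuming `pipe` is one of these AnimateDiff pipelines loaded as in the earlier sketch:

def stop_after(max_steps):
    def callback(pipeline, step_index, timestep, callback_kwargs):
        if step_index >= max_steps - 1:
            # The denoising loop sees self.interrupt and skips all remaining steps.
            pipeline._interrupt = True
        return callback_kwargs
    return callback

output = pipe(
    prompt="a rocket launching into space, cinematic",
    num_frames=16,
    num_inference_steps=25,
    callback_on_step_end=stop_after(10),
)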

View File

@@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline(
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline(
def prepare_latents(
self,
video,
height,
width,
num_channels_latents,
batch_size,
timestep,
dtype,
device,
generator,
latents=None,
video: Optional[torch.Tensor] = None,
height: int = 64,
width: int = 64,
num_channels_latents: int = 4,
batch_size: int = 1,
timestep: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
decode_chunk_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
else:
num_frames = latents.shape[2]
add_noise: bool = False,
) -> torch.Tensor:
num_frames = video.shape[1] if latents is None else latents.shape[2]
shape = (
batch_size,
num_channels_latents,
@@ -708,8 +704,13 @@ class AnimateDiffVideoToVideoPipeline(
if shape != latents.shape:
# [B, C, F, H, W]
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
latents = latents.to(device, dtype=dtype)
if add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(latents, noise, timestep)
return latents
@property
@@ -735,6 +736,10 @@ class AnimateDiffVideoToVideoPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
@@ -743,6 +748,7 @@ class AnimateDiffVideoToVideoPipeline(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
enforce_inference_steps: bool = False,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.5,
@@ -874,9 +880,10 @@ class AnimateDiffVideoToVideoPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -884,29 +891,85 @@ class AnimateDiffVideoToVideoPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
dtype = self.dtype
# 3. Encode input prompt
# 3. Prepare timesteps
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
else:
denoising_inference_steps = int(num_inference_steps / strength)
timesteps, denoising_inference_steps = retrieve_timesteps(
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
)
timesteps = timesteps[-num_inference_steps:]
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 4. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
add_noise=enforce_inference_steps,
)
# 5. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
num_frames = latents.shape[2]
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
# 6. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -916,38 +979,10 @@ class AnimateDiffVideoToVideoPipeline(
self.do_classifier_free_guidance,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 5. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Add image embeds for IP-Adapter
# 8. Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
@@ -967,9 +1002,12 @@ class AnimateDiffVideoToVideoPipeline(
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 8. Denoising loop
# 9. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1005,14 +1043,14 @@ class AnimateDiffVideoToVideoPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 9. Post-processing
# 10. Post-processing
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_chunk_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
# 11. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
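The new `enforce_inference_steps` flag changes how the video-to-video schedule interacts with `strength`: instead of truncating the requested steps, the pipeline scales the schedule up by 1 / strength, keeps only its tail, and noises the input latents to the first kept timestep (`add_noise=enforce_inference_steps` in `prepare_latents`). The bookkeeping as a standalone arithmetic sketch (it does not call the pipeline):

num_inference_steps = 20
strength = 0.5

# Default behaviour: the schedule is truncated by strength, so fewer steps actually run.
default_steps_run = int(num_inference_steps * strength)             # 10 steps

# enforce_inference_steps=True: build a longer schedule and keep only its tail,
# so exactly num_inference_steps denoising steps run on the re-noised video latents.
denoising_inference_steps = int(num_inference_steps / strength)     # 40 scheduled steps
steps_run = len(range(denoising_inference_steps)[-num_inference_steps:])  # 20 steps

print(default_steps_run, steps_run)  # 10 20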

View File

@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..models.transformers.transformer_2d import Transformer2DModel
from ..models.unets.unet_motion_model import (
AnimateDiffTransformer3D,
CrossAttnDownBlockMotion,
DownBlockMotion,
UpBlockMotion,
)
from ..pipelines.pipeline_utils import DiffusionPipeline
from ..utils import logging
from ..utils.torch_utils import randn_tensor
@@ -29,6 +34,53 @@ from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ChunkedInferenceModule(nn.Module):
def __init__(
self,
module: nn.Module,
chunk_size: int = 1,
chunk_dim: int = 0,
input_kwargs_to_chunk: List[str] = ["hidden_states"],
) -> None:
super().__init__()
self.module = module
self.chunk_size = chunk_size
self.chunk_dim = chunk_dim
self.input_kwargs_to_chunk = set(input_kwargs_to_chunk)
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
r"""Forward method of `ChunkedInferenceModule`.
Inputs that should be chunked must be passed as keyword arguments. Only the keyword arguments specified in
`input_kwargs_to_chunk` when initializing the module will be chunked.
"""
chunked_inputs = {}
for key in list(kwargs.keys()):
if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]):
continue
chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim)
kwargs.pop(key)
results = []
for chunked_input in zip(*chunked_inputs.values()):
inputs = dict(zip(chunked_inputs.keys(), chunked_input))
inputs.update(kwargs)
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
results.append(intermediate_tensor_or_tensor_tuple)
if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=self.chunk_dim)
elif isinstance(results[0], tuple):
return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)])
else:
raise ValueError(
"In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
)
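A minimal usage sketch for ChunkedInferenceModule as defined above, wrapping a plain nn.Linear (this works because nn.Linear.forward names its argument `input`; the wrapping itself is not part of the diff):

import torch
import torch.nn as nn

proj = nn.Linear(32, 64)
chunked_proj = ChunkedInferenceModule(proj, chunk_size=4, chunk_dim=0, input_kwargs_to_chunk=["input"])

x = torch.randn(16, 32)
out_full = proj(x)
out_chunked = chunked_proj(input=x)  # processed in 4 chunks of 4 rows, then concatenated
assert torch.allclose(out_full, out_chunked)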
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
@@ -69,6 +121,9 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
)
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -97,6 +152,145 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
)
def _check_inputs_free_noise(
self,
prompt,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
num_frames,
) -> None:
if not isinstance(prompt, (str, dict)):
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
if negative_prompt is not None:
if not isinstance(negative_prompt, (str, dict)):
raise ValueError(
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
)
if prompt_embeds is not None or negative_prompt_embeds is not None:
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
frame_indices = [isinstance(x, int) for x in prompt.keys()]
frame_prompts = [isinstance(x, str) for x in prompt.values()]
min_frame = min(list(prompt.keys()))
max_frame = max(list(prompt.keys()))
if not all(frame_indices):
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
if not all(frame_prompts):
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
if min_frame != 0:
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
if max_frame >= num_frames:
raise ValueError(
f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
)
def _encode_prompt_free_noise(
self,
prompt: Union[str, Dict[int, str]],
num_frames: int,
device: torch.device,
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
) -> torch.Tensor:
if negative_prompt is None:
negative_prompt = ""
# Ensure that we have a dictionary of prompts
if isinstance(prompt, str):
prompt = {0: prompt}
if isinstance(negative_prompt, str):
negative_prompt = {0: negative_prompt}
self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
# Sort the prompts based on frame indices
prompt = dict(sorted(prompt.items()))
negative_prompt = dict(sorted(negative_prompt.items()))
# Ensure that we have a prompt for the last frame index
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
frame_indices = list(prompt.keys())
frame_prompts = list(prompt.values())
frame_negative_indices = list(negative_prompt.keys())
frame_negative_prompts = list(negative_prompt.values())
# Generate and interpolate positive prompts
prompt_embeds, _ = self.encode_prompt(
prompt=frame_prompts,
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=False,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
shape = (num_frames, *prompt_embeds.shape[1:])
prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
for i in range(len(frame_indices) - 1):
start_frame = frame_indices[i]
end_frame = frame_indices[i + 1]
start_tensor = prompt_embeds[i].unsqueeze(0)
end_tensor = prompt_embeds[i + 1].unsqueeze(0)
prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
start_frame, end_frame, start_tensor, end_tensor
)
# Generate and interpolate negative prompts
negative_prompt_embeds = None
negative_prompt_interpolation_embeds = None
if do_classifier_free_guidance:
_, negative_prompt_embeds = self.encode_prompt(
prompt=[""] * len(frame_negative_prompts),
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=True,
negative_prompt=frame_negative_prompts,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
for i in range(len(frame_negative_indices) - 1):
start_frame = frame_negative_indices[i]
end_frame = frame_negative_indices[i + 1]
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
negative_prompt_interpolation_embeds[
start_frame : end_frame + 1
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds, negative_prompt_embeds
def _prepare_latents_free_noise(
self,
@@ -172,12 +366,29 @@ class AnimateDiffFreeNoiseMixin:
latents = latents[:, :, :num_frames]
return latents
def _lerp(
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
num_indices = end_index - start_index + 1
interpolated_tensors = []
for i in range(num_indices):
alpha = i / (num_indices - 1)
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
interpolated_tensors.append(interpolated_tensor)
interpolated_tensors = torch.cat(interpolated_tensors)
return interpolated_tensors
def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
prompt_interpolation_callback: Optional[
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
] = None,
) -> None:
r"""
Enable long video generation using FreeNoise.
@@ -201,7 +412,7 @@ class AnimateDiffFreeNoiseMixin:
TODO
"""
allowed_weighting_scheme = ["pyramid"]
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
if context_length > self.motion_adapter.config.motion_max_seq_length:
@@ -219,6 +430,7 @@ class AnimateDiffFreeNoiseMixin:
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
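Because `enable_free_noise` now accepts a `prompt_interpolation_callback`, the linear `_lerp` default can be swapped out. A hedged sketch of a spherical-interpolation callback matching the call site in this diff, i.e. callback(start_frame, end_frame, start_tensor, end_tensor); whether slerp improves results is not claimed by this PR:

import math
import torch

def slerp_callback(start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor:
    # Spherical interpolation between two prompt embeddings, one row per frame in [start_index, end_index].
    num_indices = end_index - start_index + 1
    flat_a, flat_b = start_tensor.flatten(), end_tensor.flatten()
    cos = ((flat_a * flat_b).sum() / (flat_a.norm() * flat_b.norm())).clamp(-1.0, 1.0)
    omega = torch.acos(cos).clamp(1e-4, math.pi - 1e-4)
    frames = []
    for i in range(num_indices):
        alpha = i / (num_indices - 1) if num_indices > 1 else 0.0
        frames.append((torch.sin((1 - alpha) * omega) * start_tensor + torch.sin(alpha * omega) * end_tensor) / torch.sin(omega))
    return torch.cat(frames)

# pipe.enable_free_noise(context_length=16, context_stride=4, prompt_interpolation_callback=slerp_callback)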
@@ -231,6 +443,56 @@ class AnimateDiffFreeNoiseMixin:
for block in blocks:
self._disable_free_noise_in_block(block)
def _enable_chunked_inference_motion_modules_(
self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int
) -> None:
for motion_module in motion_modules:
motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"])
for i in range(len(motion_module.transformer_blocks)):
motion_module.transformer_blocks[i] = ChunkedInferenceModule(
motion_module.transformer_blocks[i],
spatial_chunk_size,
0,
["hidden_states", "encoder_hidden_states"],
)
motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"])
def _enable_chunked_inference_attentions_(
self, attentions: List[Transformer2DModel], temporal_chunk_size: int
) -> None:
for i in range(len(attentions)):
attentions[i] = ChunkedInferenceModule(
attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"]
)
def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None:
for i in range(len(resnets)):
resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"])
def _enable_chunked_inference_samplers_(
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int
) -> None:
for i in range(len(samplers)):
samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"])
def enable_free_noise_chunked_inference(
self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16
) -> None:
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
if getattr(block, "motion_modules", None) is not None:
self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size)
if getattr(block, "attentions", None) is not None:
self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size)
if getattr(block, "resnets", None) is not None:
self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size)
if getattr(block, "downsamplers", None) is not None:
self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size)
if getattr(block, "upsamplers", None) is not None:
self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
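A hedged usage sketch of the new chunked-inference toggle, continuing from an AnimateDiff pipeline set up as in the earlier sketches (chunk sizes here are just the signature defaults; the speed/memory trade-off depends on resolution and frame count):

pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.enable_free_noise_chunked_inference(spatial_chunk_size=256, temporal_chunk_size=16)

# Motion modules now run on 256-position spatial chunks; resnets, attentions and (up/down)samplers
# run on 16-frame temporal chunks, trading some speed for lower peak memory.
video = pipe(prompt={0: "sunrise over the ocean", 80: "sunset over the ocean"}, num_frames=128).frames[0]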