mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
10 Commits
model-test
...
animatedif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6830fb0805 | ||
|
|
761c44d116 | ||
|
|
76f931d7c8 | ||
|
|
65686818ab | ||
|
|
ec91064966 | ||
|
|
74e3ab088c | ||
|
|
94438e1439 | ||
|
|
a86eabe0bd | ||
|
|
d55903d0b2 | ||
|
|
d0a81ae604 |
@@ -43,6 +43,12 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
|
||||
return ff_output
|
||||
|
||||
|
||||
def _experimental_split_feed_forward(
|
||||
ff: nn.Module, hidden_states: torch.Tensor, split_size: int, split_dim: int
|
||||
) -> torch.Tensor:
|
||||
return torch.cat([ff(hs_split) for hs_split in hidden_states.split(split_size, dim=split_dim)], dim=split_dim)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class GatedSelfAttentionDense(nn.Module):
|
||||
r"""
|
||||
@@ -525,7 +531,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
# ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
ff_output = _experimental_split_feed_forward(
|
||||
self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
@@ -972,15 +981,32 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "pyramid":
|
||||
if weighting_scheme == "flat":
|
||||
weights = [1.0] * num_frames
|
||||
|
||||
elif weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
mid = num_frames // 2
|
||||
weights = list(range(1, mid + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + [num_frames // 2 + 1] + weights[::-1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = list(range(1, mid))
|
||||
weights = weights + [mid] + weights[::-1]
|
||||
|
||||
elif weighting_scheme == "delayed_reverse_sawtooth":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [0.01, 2, 2, 1]
|
||||
mid = num_frames // 2
|
||||
weights = [0.01] * (mid - 1) + [mid]
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = [0.01] * mid
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
|
||||
|
||||
@@ -1087,16 +1113,38 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
|
||||
num_times_accumulated[:, frame_start:frame_end] += weights
|
||||
|
||||
hidden_states = torch.where(
|
||||
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
|
||||
for accumulated_split, num_times_split in zip(
|
||||
accumulated_values.split(self.context_length, dim=1),
|
||||
num_times_accumulated.split(self.context_length, dim=1),
|
||||
)
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
|
||||
# hidden_states = torch.where(
|
||||
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
# ).to(dtype)
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
# norm_hidden_states = torch.cat([
|
||||
# self.norm3(hs_split) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim)
|
||||
# ], dim=self._chunk_dim)
|
||||
# ff_output = torch.cat([
|
||||
# self.ff(self.norm3(hs_split)) for hs_split in hidden_states.split(self._chunk_size, self._chunk_dim)
|
||||
# ], dim=self._chunk_dim)
|
||||
|
||||
# ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
ff_output = _experimental_split_feed_forward(
|
||||
self.ff, norm_hidden_states, self._chunk_size, self._chunk_dim
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
@@ -2221,6 +2221,8 @@ class AttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
# TODO: figure out a better way to do this
|
||||
# hidden_states = torch.cat([attn.to_out[1](attn.to_out[0](x)) for x in hidden_states.split(4, dim=0)], dim=0)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.proj_in(input=hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.proj_out(input=hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states=hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
hidden_states = downsampler(hidden_states=hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
|
||||
|
||||
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
|
||||
for attn, resnet, motion_module in blocks:
|
||||
hidden_states = attn(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
hidden_states = motion_module(
|
||||
hidden_states,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -2178,7 +2157,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
|
||||
emb = emb if aug_emb is None else emb + aug_emb
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
|
||||
@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
||||
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_frames: Optional[int] = 16,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, (str, dict)):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
if self.free_noise_enabled:
|
||||
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
|
||||
prompt=prompt,
|
||||
num_frames=num_frames,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
else:
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline(
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
||||
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, (str, dict)):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
@@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
if self.free_noise_enabled:
|
||||
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
|
||||
prompt=prompt,
|
||||
num_frames=num_frames,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
else:
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
@@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline(
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline(
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
|
||||
@@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, (str, dict)):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
@@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
|
||||
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video,
|
||||
height,
|
||||
width,
|
||||
num_channels_latents,
|
||||
batch_size,
|
||||
timestep,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
height: int = 64,
|
||||
width: int = 64,
|
||||
num_channels_latents: int = 4,
|
||||
batch_size: int = 1,
|
||||
timestep: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
if latents is None:
|
||||
num_frames = video.shape[1]
|
||||
else:
|
||||
num_frames = latents.shape[2]
|
||||
|
||||
add_noise: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_frames = video.shape[1] if latents is None else latents.shape[2]
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -708,8 +704,13 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if shape != latents.shape:
|
||||
# [B, C, F, H, W]
|
||||
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
|
||||
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
if add_noise:
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
@@ -735,6 +736,10 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -743,6 +748,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
enforce_inference_steps: bool = False,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -874,9 +880,10 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
if prompt is not None and isinstance(prompt, (str, dict)):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
@@ -884,29 +891,85 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
dtype = self.dtype
|
||||
|
||||
# 3. Encode input prompt
|
||||
# 3. Prepare timesteps
|
||||
if not enforce_inference_steps:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
else:
|
||||
denoising_inference_steps = int(num_inference_steps / strength)
|
||||
timesteps, denoising_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if latents is None:
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
# Move the number of frames before the number of channels.
|
||||
video = video.permute(0, 2, 1, 3, 4)
|
||||
video = video.to(device=device, dtype=dtype)
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
num_channels_latents=num_channels_latents,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
timestep=latent_timestep,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
decode_chunk_size=decode_chunk_size,
|
||||
add_noise=enforce_inference_steps,
|
||||
)
|
||||
|
||||
# 5. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
num_frames = latents.shape[2]
|
||||
if self.free_noise_enabled:
|
||||
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
|
||||
prompt=prompt,
|
||||
num_frames=num_frames,
|
||||
device=device,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
else:
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
# 6. Prepare IP-Adapter embeddings
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
@@ -916,38 +979,10 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
if latents is None:
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
# Move the number of frames before the number of channels.
|
||||
video = video.permute(0, 2, 1, 3, 4)
|
||||
video = video.to(device=device, dtype=prompt_embeds.dtype)
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
num_channels_latents=num_channels_latents,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
timestep=latent_timestep,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
decode_chunk_size=decode_chunk_size,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
# 8. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = (
|
||||
{"image_embeds": image_embeds}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
||||
@@ -967,9 +1002,12 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8. Denoising loop
|
||||
# 9. Denoising loop
|
||||
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -1005,14 +1043,14 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# 9. Post-processing
|
||||
# 10. Post-processing
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
# 11. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -12,16 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
|
||||
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ..models.transformers.transformer_2d import Transformer2DModel
|
||||
from ..models.unets.unet_motion_model import (
|
||||
AnimateDiffTransformer3D,
|
||||
CrossAttnDownBlockMotion,
|
||||
DownBlockMotion,
|
||||
UpBlockMotion,
|
||||
)
|
||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ..utils import logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -29,6 +34,53 @@ from ..utils.torch_utils import randn_tensor
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChunkedInferenceModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
chunk_size: int = 1,
|
||||
chunk_dim: int = 0,
|
||||
input_kwargs_to_chunk: List[str] = ["hidden_states"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.module = module
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_dim = chunk_dim
|
||||
self.input_kwargs_to_chunk = set(input_kwargs_to_chunk)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
r"""Forward method of `ChunkedInferenceModule`.
|
||||
|
||||
All inputs that should be chunked should be passed as keyword arguments. Only those keywords arguments will be
|
||||
chunked that are specified in `inputs_to_chunk` when initializing the module.
|
||||
"""
|
||||
chunked_inputs = {}
|
||||
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in self.input_kwargs_to_chunk or not torch.is_tensor(kwargs[key]):
|
||||
continue
|
||||
chunked_inputs[key] = torch.split(kwargs[key], self.chunk_size, self.chunk_dim)
|
||||
kwargs.pop(key)
|
||||
|
||||
results = []
|
||||
for chunked_input in zip(*chunked_inputs.values()):
|
||||
inputs = dict(zip(chunked_inputs.keys(), chunked_input))
|
||||
inputs.update(kwargs)
|
||||
|
||||
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
|
||||
results.append(intermediate_tensor_or_tensor_tuple)
|
||||
|
||||
if isinstance(results[0], torch.Tensor):
|
||||
return torch.cat(results, dim=self.chunk_dim)
|
||||
elif isinstance(results[0], tuple):
|
||||
return tuple([torch.cat(x, dim=self.chunk_dim) for x in zip(*results)])
|
||||
else:
|
||||
raise ValueError(
|
||||
"In order to use the ChunkedInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
|
||||
)
|
||||
|
||||
|
||||
class AnimateDiffFreeNoiseMixin:
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
|
||||
|
||||
@@ -69,6 +121,9 @@ class AnimateDiffFreeNoiseMixin:
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
basic_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
motion_module.transformer_blocks[i].set_chunk_feed_forward(
|
||||
basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
|
||||
)
|
||||
|
||||
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to disable FreeNoise in transformer blocks."""
|
||||
@@ -97,6 +152,145 @@ class AnimateDiffFreeNoiseMixin:
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
free_noise_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
motion_module.transformer_blocks[i].set_chunk_feed_forward(
|
||||
free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
|
||||
)
|
||||
|
||||
def _check_inputs_free_noise(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
num_frames,
|
||||
) -> None:
|
||||
if not isinstance(prompt, (str, dict)):
|
||||
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
|
||||
|
||||
if negative_prompt is not None:
|
||||
if not isinstance(negative_prompt, (str, dict)):
|
||||
raise ValueError(
|
||||
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
|
||||
)
|
||||
|
||||
if prompt_embeds is not None or negative_prompt_embeds is not None:
|
||||
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
|
||||
|
||||
frame_indices = [isinstance(x, int) for x in prompt.keys()]
|
||||
frame_prompts = [isinstance(x, str) for x in prompt.values()]
|
||||
min_frame = min(list(prompt.keys()))
|
||||
max_frame = max(list(prompt.keys()))
|
||||
|
||||
if not all(frame_indices):
|
||||
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
|
||||
if not all(frame_prompts):
|
||||
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
|
||||
if min_frame != 0:
|
||||
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
|
||||
if max_frame >= num_frames:
|
||||
raise ValueError(
|
||||
f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
|
||||
)
|
||||
|
||||
def _encode_prompt_free_noise(
|
||||
self,
|
||||
prompt: Union[str, Dict[int, str]],
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
num_videos_per_prompt: int,
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
|
||||
# Ensure that we have a dictionary of prompts
|
||||
if isinstance(prompt, str):
|
||||
prompt = {0: prompt}
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = {0: negative_prompt}
|
||||
|
||||
self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
|
||||
|
||||
# Sort the prompts based on frame indices
|
||||
prompt = dict(sorted(prompt.items()))
|
||||
negative_prompt = dict(sorted(negative_prompt.items()))
|
||||
|
||||
# Ensure that we have a prompt for the last frame index
|
||||
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
|
||||
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
|
||||
|
||||
frame_indices = list(prompt.keys())
|
||||
frame_prompts = list(prompt.values())
|
||||
frame_negative_indices = list(negative_prompt.keys())
|
||||
frame_negative_prompts = list(negative_prompt.values())
|
||||
|
||||
# Generate and interpolate positive prompts
|
||||
prompt_embeds, _ = self.encode_prompt(
|
||||
prompt=frame_prompts,
|
||||
device=device,
|
||||
num_images_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=False,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
shape = (num_frames, *prompt_embeds.shape[1:])
|
||||
prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
|
||||
|
||||
for i in range(len(frame_indices) - 1):
|
||||
start_frame = frame_indices[i]
|
||||
end_frame = frame_indices[i + 1]
|
||||
start_tensor = prompt_embeds[i].unsqueeze(0)
|
||||
end_tensor = prompt_embeds[i + 1].unsqueeze(0)
|
||||
|
||||
prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
|
||||
start_frame, end_frame, start_tensor, end_tensor
|
||||
)
|
||||
|
||||
# Generate and interpolate negative prompts
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_interpolation_embeds = None
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
_, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=[""] * len(frame_negative_prompts),
|
||||
device=device,
|
||||
num_images_per_prompt=num_videos_per_prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=frame_negative_prompts,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
|
||||
|
||||
for i in range(len(frame_negative_indices) - 1):
|
||||
start_frame = frame_negative_indices[i]
|
||||
end_frame = frame_negative_indices[i + 1]
|
||||
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
|
||||
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
|
||||
|
||||
negative_prompt_interpolation_embeds[
|
||||
start_frame : end_frame + 1
|
||||
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
|
||||
|
||||
prompt_embeds = prompt_interpolation_embeds
|
||||
negative_prompt_embeds = negative_prompt_interpolation_embeds
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def _prepare_latents_free_noise(
|
||||
self,
|
||||
@@ -172,12 +366,29 @@ class AnimateDiffFreeNoiseMixin:
|
||||
latents = latents[:, :, :num_frames]
|
||||
return latents
|
||||
|
||||
def _lerp(
|
||||
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
num_indices = end_index - start_index + 1
|
||||
interpolated_tensors = []
|
||||
|
||||
for i in range(num_indices):
|
||||
alpha = i / (num_indices - 1)
|
||||
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
|
||||
interpolated_tensors.append(interpolated_tensor)
|
||||
|
||||
interpolated_tensors = torch.cat(interpolated_tensors)
|
||||
return interpolated_tensors
|
||||
|
||||
def enable_free_noise(
|
||||
self,
|
||||
context_length: Optional[int] = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
noise_type: str = "shuffle_context",
|
||||
prompt_interpolation_callback: Optional[
|
||||
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable long video generation using FreeNoise.
|
||||
@@ -201,7 +412,7 @@ class AnimateDiffFreeNoiseMixin:
|
||||
TODO
|
||||
"""
|
||||
|
||||
allowed_weighting_scheme = ["pyramid"]
|
||||
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
|
||||
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
|
||||
|
||||
if context_length > self.motion_adapter.config.motion_max_seq_length:
|
||||
@@ -219,6 +430,7 @@ class AnimateDiffFreeNoiseMixin:
|
||||
self._free_noise_context_stride = context_stride
|
||||
self._free_noise_weighting_scheme = weighting_scheme
|
||||
self._free_noise_noise_type = noise_type
|
||||
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
|
||||
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
@@ -231,6 +443,56 @@ class AnimateDiffFreeNoiseMixin:
|
||||
for block in blocks:
|
||||
self._disable_free_noise_in_block(block)
|
||||
|
||||
def _enable_chunked_inference_motion_modules_(
|
||||
self, motion_modules: List[AnimateDiffTransformer3D], spatial_chunk_size: int
|
||||
) -> None:
|
||||
for motion_module in motion_modules:
|
||||
motion_module.proj_in = ChunkedInferenceModule(motion_module.proj_in, spatial_chunk_size, 0, ["input"])
|
||||
|
||||
for i in range(len(motion_module.transformer_blocks)):
|
||||
motion_module.transformer_blocks[i] = ChunkedInferenceModule(
|
||||
motion_module.transformer_blocks[i],
|
||||
spatial_chunk_size,
|
||||
0,
|
||||
["hidden_states", "encoder_hidden_states"],
|
||||
)
|
||||
|
||||
motion_module.proj_out = ChunkedInferenceModule(motion_module.proj_out, spatial_chunk_size, 0, ["input"])
|
||||
|
||||
def _enable_chunked_inference_attentions_(
|
||||
self, attentions: List[Transformer2DModel], temporal_chunk_size: int
|
||||
) -> None:
|
||||
for i in range(len(attentions)):
|
||||
attentions[i] = ChunkedInferenceModule(
|
||||
attentions[i], temporal_chunk_size, 0, ["hidden_states", "encoder_hidden_states"]
|
||||
)
|
||||
|
||||
def _enable_chunked_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_chunk_size: int) -> None:
|
||||
for i in range(len(resnets)):
|
||||
resnets[i] = ChunkedInferenceModule(resnets[i], temporal_chunk_size, 0, ["input_tensor", "temb"])
|
||||
|
||||
def _enable_chunked_inference_samplers_(
|
||||
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_chunk_size: int
|
||||
) -> None:
|
||||
for i in range(len(samplers)):
|
||||
samplers[i] = ChunkedInferenceModule(samplers[i], temporal_chunk_size, 0, ["hidden_states"])
|
||||
|
||||
def enable_free_noise_chunked_inference(
|
||||
self, spatial_chunk_size: int = 256, temporal_chunk_size: int = 16
|
||||
) -> None:
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
if getattr(block, "motion_modules", None) is not None:
|
||||
self._enable_chunked_inference_motion_modules_(block.motion_modules, spatial_chunk_size)
|
||||
if getattr(block, "attentions", None) is not None:
|
||||
self._enable_chunked_inference_attentions_(block.attentions, temporal_chunk_size)
|
||||
if getattr(block, "resnets", None) is not None:
|
||||
self._enable_chunked_inference_resnets_(block.resnets, temporal_chunk_size)
|
||||
if getattr(block, "downsamplers", None) is not None:
|
||||
self._enable_chunked_inference_samplers_(block.downsamplers, temporal_chunk_size)
|
||||
if getattr(block, "upsamplers", None) is not None:
|
||||
self._enable_chunked_inference_samplers_(block.upsamplers, temporal_chunk_size)
|
||||
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
|
||||
|
||||
Reference in New Issue
Block a user