Compare commits

...

15 Commits

Author  SHA1        Message  Date
Aryan  5a60a62c47  add experimental support for num_frames not perfectly fitting context length, context stride  2024-07-27 16:24:36 +02:00
Aryan  691facfc2e  copy animatediff controlnet implementation from #8972  2024-07-27 15:34:39 +02:00
Aryan  dc96a8d5cd  make style  2024-07-27 15:31:12 +02:00
Aryan  1b7bc007d8  make fix-copies  2024-07-27 15:30:42 +02:00
Aryan  1bb09845bf  fix copied from comments  2024-07-27 15:29:30 +02:00
Aryan  024e2da864  make style  2024-07-27 15:27:28 +02:00
Aryan  f6897ae46a  add decode batch size param to all pipelines  2024-07-27 15:26:14 +02:00
Aryan  a41f843dba  remove old helper functions  2024-07-27 15:16:45 +02:00
Aryan  10b65b310c  add freenoise  2024-07-27 15:16:02 +02:00
Aryan  610f433d1c  revert attention changes  2024-07-27 15:15:34 +02:00
Aryan  690dad693f  Merge branch 'main' into freenoise  2024-07-27 13:05:34 +02:00
Aryan  2e97ba7ccb  Merge branch 'main' into freenoise  2024-07-25 03:55:41 +05:30
Aryan  5d0f4c3407  add animatediff controlnet implementation  2024-07-24 23:54:42 +02:00
Aryan  441d321152  fix freeinit bug  2024-07-24 23:54:29 +02:00
Aryan  80e530fbfa  initial work draft for freenoise; needs massive cleanup  2024-07-24 01:38:18 +02:00
9 changed files with 1616 additions and 28 deletions

View File

@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
attention_out_bias: bool = True,
):
super().__init__()
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.dropout = dropout
+ self.cross_attention_dim = cross_attention_dim
+ self.activation_fn = activation_fn
+ self.attention_bias = attention_bias
+ self.double_self_attention = double_self_attention
+ self.norm_elementwise_affine = norm_elementwise_affine
+ self.positional_embeddings = positional_embeddings
+ self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Any, Dict, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

@@ -21,6 +21,8 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
from ...utils import logging
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward, _chunked_feed_forward
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,

@@ -33,7 +35,7 @@ from ..attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
- from ..embeddings import TimestepEmbedding, Timesteps
+ from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn

@@ -53,6 +55,302 @@ from .unet_3d_condition import UNet3DConditionOutput
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
r"""
A FreeNoise Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
context_length: int = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
return weights
def set_free_noise_properties(
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
self.context_length = context_length
self.context_stride = context_stride
self.weighting_scheme = weighting_scheme
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
*args,
**kwargs,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
# hidden_states: [B x H x W, F, C]
device = hidden_states.device
dtype = hidden_states.dtype
num_frames = hidden_states.size(1)
frame_indices = self._get_frame_indices(num_frames)
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
# Handle the out-of-bounds case where the last window does not end exactly at num_frames
# For example, with num_frames=25, context_length=16, context_stride=4, we expect the ranges:
# [(0, 16), (4, 20), (8, 24), (9, 25)]
if not is_last_frame_batch_complete:
if num_frames < self.context_length:
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
last_frame_batch_length = num_frames - frame_indices[-1][1]
frame_indices.append((num_frames - self.context_length, num_frames))
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
accumulated_values = torch.zeros_like(hidden_states)
for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to handle cases where (frame_end - frame_start) is smaller than
# `context_length`, e.g. frame_indices=[(0, 16), (16, 20)] for a 20-frame video, i.e. a frame count
# that is not a multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
# assert self.norm_type == "layer_norm"
norm_hidden_states = self.norm1(hidden_states_chunk)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if hidden_states_chunk.ndim == 4:
hidden_states_chunk = hidden_states_chunk.squeeze(1)
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states_chunk)
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
accumulated_values[:, -last_frame_batch_length:] += (
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
)
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
else:
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
).to(dtype)
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class MotionModules(nn.Module):
def __init__(
self,
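To make the sliding-window logic above easier to follow, here is a small self-contained sketch (plain Python, independent of diffusers) that mirrors the `_get_frame_indices` / `_get_frame_weights` logic and prints the windows and pyramid weights for the `num_frames=25`, `context_length=16`, `context_stride=4` case discussed in `forward`:

```python
from typing import List, Tuple

def get_frame_indices(num_frames: int, context_length: int, context_stride: int) -> List[Tuple[int, int]]:
    # Sliding windows of `context_length` frames, advancing by `context_stride`.
    # Assumes num_frames >= context_length (the block raises an error otherwise).
    frame_indices = []
    for i in range(0, num_frames - context_length + 1, context_stride):
        frame_indices.append((i, min(num_frames, i + context_length)))
    # Same out-of-bounds handling as FreeNoiseTransformerBlock.forward: if the last
    # window does not end at num_frames, append one final window that does.
    if frame_indices[-1][1] != num_frames:
        frame_indices.append((num_frames - context_length, num_frames))
    return frame_indices

def get_frame_weights(num_frames: int) -> List[int]:
    # "pyramid" weighting: frames near the middle of a window get larger weights,
    # so overlapping windows blend smoothly once the accumulated sums are normalized.
    weights = list(range(1, num_frames // 2 + 1))
    if num_frames % 2 == 0:
        return weights + weights[::-1]  # e.g. 4 -> [1, 2, 2, 1]
    return weights + [num_frames // 2 + 1] + weights[::-1]  # e.g. 5 -> [1, 2, 3, 2, 1]

print(get_frame_indices(num_frames=25, context_length=16, context_stride=4))
# [(0, 16), (4, 20), (8, 24), (9, 25)]
print(get_frame_weights(16))
# [1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1]
```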

View File

@@ -42,6 +42,7 @@ from ...utils import (
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation.

@@ -394,15 +396,20 @@ class AnimateDiffPipeline(
return ip_adapter_image_embeds

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video

@@ -495,7 +502,6 @@ class AnimateDiffPipeline(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

- # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):

@@ -517,6 +523,22 @@ class AnimateDiffPipeline(
else:
latents = latents.to(device)

+ if self.free_noise_enabled and self._free_noise_shuffle:
+ for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
+ # ensure window is within bounds
+ window_start = max(0, i - self._free_noise_context_length)
+ window_end = min(num_frames, window_start + self._free_noise_context_stride)
+ window_length = window_end - window_start
+ if window_length == 0:
+ break
+ indices = torch.LongTensor(list(range(window_start, window_end)))
+ shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
+ # shuffle latents in every window
+ latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices]

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
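The shuffle loop above is easy to misread: `i` starts at `context_length`, but each shuffled window only spans `context_stride` frames. A toy sketch (plain PyTorch, outside the pipeline) makes the behaviour of this draft concrete; note that with `num_frames=32`, `context_length=16`, `context_stride=4` only the first 16 frames end up permuted:

```python
import torch

num_frames, context_length, context_stride = 32, 16, 4
latents = torch.randn(2, 4, num_frames, 8, 8)  # toy [batch, channels, frames, height, width]

for i in range(context_length, num_frames, context_stride):
    window_start = max(0, i - context_length)
    window_end = min(num_frames, window_start + context_stride)
    window_length = window_end - window_start
    if window_length == 0:
        break
    # permute the frame order inside this window, mirroring the draft's indexing
    indices = torch.arange(window_start, window_end)
    shuffled_indices = indices[torch.randperm(window_length)]
    latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices]
    print(f"shuffled window [{window_start}, {window_end}) -> {shuffled_indices.tolist()}")

# shuffled window [0, 4)  -> permutation of [0, 1, 2, 3]
# shuffled window [4, 8)  -> permutation of [4, 5, 6, 7]
# shuffled window [8, 12) -> permutation of [8, 9, 10, 11]
# shuffled window [12, 16)-> permutation of [12, 13, 14, 15]
```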
@@ -569,6 +591,7 @@ class AnimateDiffPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
**kwargs,
):
r"""

@@ -637,6 +660,8 @@ class AnimateDiffPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ decode_batch_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -808,7 +833,7 @@ class AnimateDiffPipeline(
if output_type == "latent":
video = latents
else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
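Taken together, the changes to this pipeline surface two user-facing knobs: `enable_free_noise()` (from `AnimateDiffFreeNoiseMixin`) and `decode_batch_size`. A minimal usage sketch, assuming a standard AnimateDiff setup; the checkpoint names below are illustrative examples and not part of this diff:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Example checkpoints; any AnimateDiff-compatible base model + motion adapter should work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise: process the video in overlapping windows of `context_length` frames,
# sliding by `context_stride`, and blend the windows with the "pyramid" weighting.
pipe.enable_free_noise(context_length=16, context_stride=4, weighting_scheme="pyramid")

output = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=64,          # longer than the adapter's usual 16/24/32-frame training length
    num_inference_steps=25,
    decode_batch_size=8,    # decode the VAE in chunks of 8 frames to save memory
)
export_to_gif(output.frames[0], "freenoise.gif")
```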

File diff suppressed because it is too large

View File

@@ -56,6 +56,7 @@ from ...utils import (
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -194,6 +195,7 @@ class AnimateDiffSDXLPipeline(
TextualInversionLoaderMixin,
IPAdapterMixin,
FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation using Stable Diffusion XL.

@@ -606,15 +608,21 @@ class AnimateDiffSDXLPipeline(
return ip_adapter_image_embeds

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video

@@ -876,6 +884,7 @@ class AnimateDiffSDXLPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
):
r"""
Function invoked when calling the pipeline for generation.

@@ -1015,6 +1024,8 @@ class AnimateDiffSDXLPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ decode_batch_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -1258,7 +1269,7 @@ class AnimateDiffSDXLPipeline(
if output_type == "latent":
video = latents
else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# cast back to fp16 if needed

View File

@@ -38,6 +38,7 @@ from ...utils import (
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -127,6 +128,7 @@ class AnimateDiffSparseControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls

@@ -448,15 +450,21 @@ class AnimateDiffSparseControlNetPipeline(
return ip_adapter_image_embeds

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video

@@ -728,6 +736,7 @@ class AnimateDiffSparseControlNetPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.

@@ -806,6 +815,8 @@ class AnimateDiffSparseControlNetPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ decode_batch_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -996,7 +1007,7 @@ class AnimateDiffSparseControlNetPipeline(
if output_type == "latent":
video = latents
else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 12. Offload all models

View File

@@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for video-to-video generation.

@@ -498,15 +500,21 @@ class AnimateDiffVideoToVideoPipeline(
return ip_adapter_image_embeds

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video

@@ -747,6 +755,7 @@ class AnimateDiffVideoToVideoPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.

@@ -822,6 +831,8 @@ class AnimateDiffVideoToVideoPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ decode_batch_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -990,7 +1001,7 @@ class AnimateDiffVideoToVideoPipeline(
if output_type == "latent":
video = latents
else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models

View File

@@ -0,0 +1,141 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from ..models.attention import BasicTransformerBlock
from ..models.unets.unet_motion_model import (
CrossAttnDownBlockMotion,
DownBlockMotion,
FreeNoiseTransformerBlock,
TransformerTemporalModel,
UpBlockMotion,
)
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to enable FreeNoise in transformer blocks."""
for motion_module in block.motion_modules:
motion_module: TransformerTemporalModel
num_transformer_blocks = len(motion_module.transformer_blocks)
for i in range(num_transformer_blocks):
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
motion_module.transformer_blocks[i].set_free_noise_properties(
self._free_noise_context_length,
self._free_noise_context_stride,
self._free_noise_weighting_scheme,
)
else:
assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
basic_transfomer_block = motion_module.transformer_blocks[i]
motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
dim=basic_transfomer_block.dim,
num_attention_heads=basic_transfomer_block.num_attention_heads,
attention_head_dim=basic_transfomer_block.attention_head_dim,
dropout=basic_transfomer_block.dropout,
cross_attention_dim=basic_transfomer_block.cross_attention_dim,
activation_fn=basic_transfomer_block.activation_fn,
attention_bias=basic_transfomer_block.attention_bias,
only_cross_attention=basic_transfomer_block.only_cross_attention,
double_self_attention=basic_transfomer_block.double_self_attention,
positional_embeddings=basic_transfomer_block.positional_embeddings,
num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
context_length=self._free_noise_context_length,
context_stride=self._free_noise_context_stride,
weighting_scheme=self._free_noise_weighting_scheme,
).to(device=self.device, dtype=self.dtype)
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
for motion_module in block.motion_modules:
motion_module: TransformerTemporalModel
num_transformer_blocks = len(motion_module.transformer_blocks)
for i in range(num_transformer_blocks):
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
free_noise_transfomer_block = motion_module.transformer_blocks[i]
motion_module.transformer_blocks[i] = BasicTransformerBlock(
dim=free_noise_transfomer_block.dim,
num_attention_heads=free_noise_transfomer_block.num_attention_heads,
attention_head_dim=free_noise_transfomer_block.attention_head_dim,
dropout=free_noise_transfomer_block.dropout,
cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
activation_fn=free_noise_transfomer_block.activation_fn,
attention_bias=free_noise_transfomer_block.attention_bias,
only_cross_attention=free_noise_transfomer_block.only_cross_attention,
double_self_attention=free_noise_transfomer_block.double_self_attention,
positional_embeddings=free_noise_transfomer_block.positional_embeddings,
num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
).to(device=self.device, dtype=self.dtype)
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
shuffle: bool = True,
) -> None:
r"""
Enable long video generation using FreeNoise.
Args:
context_length (`int`, defaults to `16`, *optional*):
The number of video frames to process at once. It's recommended to set this to the maximum frames the
Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
adapter config is used.
context_stride (`int`, *optional*):
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
[0, 15], [4, 19], [8, 23] (0-based indexing)
weighting_scheme (`str`, defaults to `"pyramid"`):
TODO(aryan)
shuffle (`bool`, defaults to `True`):
TODO(aryan): decide if this is even needed
"""
self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_shuffle = shuffle
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._enable_free_noise_in_block(block)
def disable_free_noise(self) -> None:
self._free_noise_context_length = None
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._disable_free_noise_in_block(block)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
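Since `_enable_free_noise_in_block` and `_disable_free_noise_in_block` copy the `state_dict` strictly in both directions, toggling FreeNoise should leave the motion-module weights untouched. A small, hypothetical sanity check along those lines, reusing a `pipe` object loaded as in the earlier usage sketch:

```python
import torch

def param_fingerprint(module: torch.nn.Module) -> torch.Tensor:
    # Cheap aggregate fingerprint of all parameters in the module.
    return torch.stack([p.detach().float().sum() for p in module.parameters()]).sum()

before = param_fingerprint(pipe.unet)
pipe.enable_free_noise(context_length=16, context_stride=4, weighting_scheme="pyramid")
pipe.disable_free_noise()
after = param_fingerprint(pipe.unet)

# The block swap replaces BasicTransformerBlock instances with FreeNoiseTransformerBlock
# (and back) via strict load_state_dict, so the parameter values should be preserved.
assert torch.allclose(before, after), "enable/disable FreeNoise should preserve motion-module weights"
```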

View File

@@ -45,6 +45,7 @@ from ...utils import (
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
+ from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin

@@ -131,6 +132,7 @@ class PIAPipeline(
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
FreeInitMixin,
+ AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation.

@@ -407,15 +409,21 @@ class PIAPipeline(
return image_embeds, uncond_image_embeds

- # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
- def decode_latents(self, latents):
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+ def decode_latents(self, latents, decode_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- image = self.vae.decode(latents).sample
- video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+ video = []
+ for i in range(0, latents.shape[0], decode_batch_size):
+ batch_latents = latents[i : i + decode_batch_size]
+ batch_latents = self.vae.decode(batch_latents).sample
+ video.append(batch_latents)
+ video = torch.cat(video)
+ video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video

@@ -687,6 +695,7 @@ class PIAPipeline(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ decode_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.

@@ -763,6 +772,8 @@ class PIAPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ decode_batch_size (`int`, defaults to `16`):
+ The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -931,7 +942,7 @@ class PIAPipeline(
if output_type == "latent":
video = latents
else:
- video_tensor = self.decode_latents(latents)
+ video_tensor = self.decode_latents(latents, decode_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models