Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 13:34:27 +08:00)

Compare commits: ci-test-hu ... freenoise- (15 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 5a60a62c47 |  |
|  | 691facfc2e |  |
|  | dc96a8d5cd |  |
|  | 1b7bc007d8 |  |
|  | 1bb09845bf |  |
|  | 024e2da864 |  |
|  | f6897ae46a |  |
|  | a41f843dba |  |
|  | 10b65b310c |  |
|  | 610f433d1c |  |
|  | 690dad693f |  |
|  | 2e97ba7ccb |  |
|  | 5d0f4c3407 |  |
|  | 441d321152 |  |
|  | 80e530fbfa |  |
src/diffusers/models/attention.py:

```diff
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
         attention_out_bias: bool = True,
     ):
         super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
         self.only_cross_attention = only_cross_attention

         # We keep these boolean flags for backward-compatibility.
```
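The constructor arguments are stored on `self` so that code outside the block (such as the FreeNoise mixin later in this changeset) can read a block's configuration and rebuild an equivalent module. A minimal sketch of that pattern, using a hypothetical `SmallBlock` stand-in rather than the real `BasicTransformerBlock`:

```python
import torch.nn as nn


class SmallBlock(nn.Module):
    # Toy stand-in: records its config on `self` so other code can later
    # reconstruct a structurally identical module.
    def __init__(self, dim: int, num_attention_heads: int):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.proj = nn.Linear(dim, dim)


def clone_with_same_config(block: SmallBlock) -> SmallBlock:
    # Rebuild a block from the stored attributes and copy the weights over,
    # mirroring what the FreeNoise mixin does when it swaps
    # BasicTransformerBlock for FreeNoiseTransformerBlock.
    new_block = SmallBlock(dim=block.dim, num_attention_heads=block.num_attention_heads)
    new_block.load_state_dict(block.state_dict(), strict=True)
    return new_block
```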
src/diffusers/models/unets/unet_motion_model.py:

```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -21,6 +21,8 @@ import torch.utils.checkpoint
 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
 from ...utils import logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward, _chunked_feed_forward
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
@@ -33,7 +35,7 @@ from ..attention_processor import (
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
-from ..embeddings import TimestepEmbedding, Timesteps
+from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_2d_blocks import UNetMidBlock2DCrossAttn
```
```diff
@@ -53,6 +55,302 @@ from .unet_3d_condition import UNet3DConditionOutput

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@maybe_allow_in_graph
+class FreeNoiseTransformerBlock(nn.Module):
+    r"""
+    A FreeNoise Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (:
+            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout: float = 0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+        context_length: int = 16,
+        context_stride: int = 4,
+        weighting_scheme: str = "pyramid",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
+        self.only_cross_attention = only_cross_attention
+
+        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
+
+        # We keep these boolean flags for backward-compatibility.
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        self.norm_type = norm_type
+        self.num_embeds_ada_norm = num_embeds_ada_norm
+
+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is none
+
+        # 3. Feed-forward
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
+        frame_indices = []
+        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
+            window_start = i
+            window_end = min(num_frames, i + self.context_length)
+            frame_indices.append((window_start, window_end))
+
+        return frame_indices
+
+    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
+        if weighting_scheme == "pyramid":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [1, 2, 2, 1]
+                weights = list(range(1, num_frames // 2 + 1))
+                weights = weights + weights[::-1]
+            else:
+                # num_frames = 5 => [1, 2, 3, 2, 1]
+                weights = list(range(1, num_frames // 2 + 1))
+                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+        else:
+            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
+
+        return weights
+
+    def set_free_noise_properties(
+        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+    ) -> None:
+        self.context_length = context_length
+        self.context_stride = context_stride
+        self.weighting_scheme = weighting_scheme
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        # hidden_states: [B x H x W, F, C]
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        num_frames = hidden_states.size(1)
+        frame_indices = self._get_frame_indices(num_frames)
+        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+        # [(0, 16), (4, 20), (8, 24), (10, 26)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
+
+        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+        accumulated_values = torch.zeros_like(hidden_states)
+
+        for i, (frame_start, frame_end) in enumerate(frame_indices):
+            # Slicing the weights to the actual window length handles cases like
+            # frame_indices=[(0, 16), (16, 20)], e.g. when the user provides a video whose frame
+            # count is not a multiple of `context_length`.
+            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+            weights *= frame_weights
+
+            hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+            # Notice that normalization is always applied before the real computation in the following blocks.
+            # 1. Self-Attention
+            # assert self.norm_type == "layer_norm"
+            norm_hidden_states = self.norm1(hidden_states_chunk)
+
+            if self.pos_embed is not None:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+
+            hidden_states_chunk = attn_output + hidden_states_chunk
+            if hidden_states_chunk.ndim == 4:
+                hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+            # 2. Cross-Attention
+            if self.attn2 is not None:
+                norm_hidden_states = self.norm2(hidden_states_chunk)
+
+                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                    norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+                attn_output = self.attn2(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    **cross_attention_kwargs,
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                accumulated_values[:, -last_frame_batch_length:] += (
+                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                )
+                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
+            else:
+                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                num_times_accumulated[:, frame_start:frame_end] += weights
+
+        hidden_states = torch.where(
+            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        ).to(dtype)
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
 class MotionModules(nn.Module):
     def __init__(
         self,
```
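To make the windowing concrete, here is a small standalone sketch of the `_get_frame_indices` / `_get_frame_weights` logic from the block above, written as plain functions with the same arithmetic (a sketch for illustration, not the module's API):

```python
from typing import List, Tuple


def get_frame_indices(num_frames: int, context_length: int, context_stride: int) -> List[Tuple[int, int]]:
    # Sliding windows of `context_length` frames, advancing by `context_stride`.
    return [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]


def get_frame_weights(num_frames: int) -> List[int]:
    # "pyramid" weighting: frames near the middle of a window get larger weights,
    # e.g. 4 -> [1, 2, 2, 1] and 5 -> [1, 2, 3, 2, 1].
    half = list(range(1, num_frames // 2 + 1))
    return half + half[::-1] if num_frames % 2 == 0 else half + [num_frames // 2 + 1] + half[::-1]


# For num_frames=24, context_length=16, context_stride=4 the windows are
# [(0, 16), (4, 20), (8, 24)]; overlapping outputs are then averaged with the
# pyramid weights, which is what FreeNoiseTransformerBlock.forward does.
print(get_frame_indices(24, 16, 4))  # [(0, 16), (4, 20), (8, 24)]
print(get_frame_weights(4))          # [1, 2, 2, 1]
```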
src/diffusers/pipelines/animatediff/pipeline_animatediff.py:

```diff
@@ -42,6 +42,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation.
@@ -394,15 +396,20 @@ class AnimateDiffPipeline(

         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, decode_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -495,7 +502,6 @@ class AnimateDiffPipeline(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
@@ -517,6 +523,22 @@ class AnimateDiffPipeline(
         else:
             latents = latents.to(device)

+        if self.free_noise_enabled and self._free_noise_shuffle:
+            for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
+                # ensure window is within bounds
+                window_start = max(0, i - self._free_noise_context_length)
+                window_end = min(num_frames, window_start + self._free_noise_context_stride)
+                window_length = window_end - window_start
+
+                if window_length == 0:
+                    break
+
+                indices = torch.LongTensor(list(range(window_start, window_end)))
+                shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
+
+                # shuffle latents in every window
+                latents[:, :, window_start:window_end] = latents[:, :, shuffled_indices]
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
         return latents
@@ -569,6 +591,7 @@ class AnimateDiffPipeline(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
         **kwargs,
     ):
         r"""
@@ -637,6 +660,8 @@ class AnimateDiffPipeline(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            decode_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.

         Examples:

@@ -808,7 +833,7 @@ class AnimateDiffPipeline(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # 10. Offload all models
```
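The `prepare_latents` change above reorders frame latents within sliding windows of the initial noise when FreeNoise is enabled with `shuffle=True`. A small self-contained sketch of that per-window shuffling step on a dummy latent tensor (the shapes and the seed are illustrative, not the pipeline's values):

```python
import torch

# Dummy video latents: [batch, channels, frames, height, width]
latents = torch.randn(1, 4, 32, 8, 8)
num_frames = latents.size(2)
context_length, context_stride = 16, 4
generator = torch.Generator().manual_seed(0)

# Reorder frame latents inside each sliding window, mirroring the loop added
# to prepare_latents above.
for i in range(context_length, num_frames, context_stride):
    window_start = max(0, i - context_length)
    window_end = min(num_frames, window_start + context_stride)
    window_length = window_end - window_start
    if window_length == 0:
        break
    indices = torch.arange(window_start, window_end)
    shuffled = indices[torch.randperm(window_length, generator=generator)]
    latents[:, :, window_start:window_end] = latents[:, :, shuffled]
```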
(File diff suppressed because it is too large.)
src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py:

```diff
@@ -56,6 +56,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -194,6 +195,7 @@ class AnimateDiffSDXLPipeline(
     TextualInversionLoaderMixin,
     IPAdapterMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation using Stable Diffusion XL.
@@ -606,15 +608,21 @@ class AnimateDiffSDXLPipeline(

         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, decode_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -876,6 +884,7 @@ class AnimateDiffSDXLPipeline(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1015,6 +1024,8 @@ class AnimateDiffSDXLPipeline(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            decode_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.

         Examples:

@@ -1258,7 +1269,7 @@ class AnimateDiffSDXLPipeline(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # cast back to fp16 if needed
```
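Each pipeline's `decode_latents` now pushes the flattened frame batch through the VAE in chunks of `decode_batch_size` frames instead of all at once, which keeps peak memory bounded for long videos. A standalone sketch of the same batching pattern with a stand-in `decode` function (the shapes and the helper are illustrative, not the pipeline's API):

```python
import torch


def decode(latents_chunk: torch.Tensor) -> torch.Tensor:
    # Stand-in for `self.vae.decode(...).sample`: maps 4-channel latents to
    # 3-channel "images" just to make the shapes concrete.
    b, c, h, w = latents_chunk.shape
    return torch.zeros(b, 3, h * 8, w * 8)


def decode_in_batches(latents: torch.Tensor, decode_batch_size: int = 16) -> torch.Tensor:
    # latents: [batch_size * num_frames, channels, height, width]
    frames = []
    for i in range(0, latents.shape[0], decode_batch_size):
        frames.append(decode(latents[i : i + decode_batch_size]))
    return torch.cat(frames)


decoded = decode_in_batches(torch.randn(40, 4, 32, 32), decode_batch_size=16)
print(decoded.shape)  # torch.Size([40, 3, 256, 256])
```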
src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py:

```diff
@@ -38,6 +38,7 @@ from ...utils import (
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -127,6 +128,7 @@ class AnimateDiffSparseControlNetPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
@@ -448,15 +450,21 @@ class AnimateDiffSparseControlNetPipeline(

         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, decode_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -728,6 +736,7 @@ class AnimateDiffSparseControlNetPipeline(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -806,6 +815,8 @@ class AnimateDiffSparseControlNetPipeline(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            decode_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.

         Examples:

@@ -996,7 +1007,7 @@ class AnimateDiffSparseControlNetPipeline(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # 12. Offload all models
```
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py:

```diff
@@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput

@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for video-to-video generation.
@@ -498,15 +500,21 @@ class AnimateDiffVideoToVideoPipeline(

         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, decode_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -747,6 +755,7 @@ class AnimateDiffVideoToVideoPipeline(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -822,6 +831,8 @@ class AnimateDiffVideoToVideoPipeline(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            decode_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.

         Examples:

@@ -990,7 +1001,7 @@ class AnimateDiffVideoToVideoPipeline(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # 10. Offload all models
```
src/diffusers/pipelines/free_noise_utils.py (new file, 141 lines):

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from ..models.attention import BasicTransformerBlock
from ..models.unets.unet_motion_model import (
    CrossAttnDownBlockMotion,
    DownBlockMotion,
    FreeNoiseTransformerBlock,
    TransformerTemporalModel,
    UpBlockMotion,
)


class AnimateDiffFreeNoiseMixin:
    r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""

    def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
        r"""Helper function to enable FreeNoise in transformer blocks."""

        for motion_module in block.motion_modules:
            motion_module: TransformerTemporalModel
            num_transformer_blocks = len(motion_module.transformer_blocks)

            for i in range(num_transformer_blocks):
                if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
                    motion_module.transformer_blocks[i].set_free_noise_properties(
                        self._free_noise_context_length,
                        self._free_noise_context_stride,
                        self._free_noise_weighting_scheme,
                    )
                else:
                    assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
                    basic_transformer_block = motion_module.transformer_blocks[i]

                    motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
                        dim=basic_transformer_block.dim,
                        num_attention_heads=basic_transformer_block.num_attention_heads,
                        attention_head_dim=basic_transformer_block.attention_head_dim,
                        dropout=basic_transformer_block.dropout,
                        cross_attention_dim=basic_transformer_block.cross_attention_dim,
                        activation_fn=basic_transformer_block.activation_fn,
                        attention_bias=basic_transformer_block.attention_bias,
                        only_cross_attention=basic_transformer_block.only_cross_attention,
                        double_self_attention=basic_transformer_block.double_self_attention,
                        positional_embeddings=basic_transformer_block.positional_embeddings,
                        num_positional_embeddings=basic_transformer_block.num_positional_embeddings,
                        context_length=self._free_noise_context_length,
                        context_stride=self._free_noise_context_stride,
                        weighting_scheme=self._free_noise_weighting_scheme,
                    ).to(device=self.device, dtype=self.dtype)

                    motion_module.transformer_blocks[i].load_state_dict(
                        basic_transformer_block.state_dict(), strict=True
                    )

    def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
        r"""Helper function to disable FreeNoise in transformer blocks."""

        for motion_module in block.motion_modules:
            motion_module: TransformerTemporalModel
            num_transformer_blocks = len(motion_module.transformer_blocks)

            for i in range(num_transformer_blocks):
                if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
                    free_noise_transformer_block = motion_module.transformer_blocks[i]

                    motion_module.transformer_blocks[i] = BasicTransformerBlock(
                        dim=free_noise_transformer_block.dim,
                        num_attention_heads=free_noise_transformer_block.num_attention_heads,
                        attention_head_dim=free_noise_transformer_block.attention_head_dim,
                        dropout=free_noise_transformer_block.dropout,
                        cross_attention_dim=free_noise_transformer_block.cross_attention_dim,
                        activation_fn=free_noise_transformer_block.activation_fn,
                        attention_bias=free_noise_transformer_block.attention_bias,
                        only_cross_attention=free_noise_transformer_block.only_cross_attention,
                        double_self_attention=free_noise_transformer_block.double_self_attention,
                        positional_embeddings=free_noise_transformer_block.positional_embeddings,
                        num_positional_embeddings=free_noise_transformer_block.num_positional_embeddings,
                    ).to(device=self.device, dtype=self.dtype)

                    motion_module.transformer_blocks[i].load_state_dict(
                        free_noise_transformer_block.state_dict(), strict=True
                    )

    def enable_free_noise(
        self,
        context_length: Optional[int] = 16,
        context_stride: int = 4,
        weighting_scheme: str = "pyramid",
        shuffle: bool = True,
    ) -> None:
        r"""
        Enable long video generation using FreeNoise.

        Args:
            context_length (`int`, defaults to `16`, *optional*):
                The number of video frames to process at once. It's recommended to set this to the maximum frames the
                Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
                adapter config is used.
            context_stride (`int`, *optional*):
                Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
                windows of size `context_length`. Context stride allows you to specify how many frames to skip between
                each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
                [0, 15], [4, 19], [8, 23] (0-based indexing)
            weighting_scheme (`str`, defaults to `"pyramid"`):
                TODO(aryan)
            shuffle (`bool`, defaults to `True`):
                TODO(aryan): decide if this is even needed
        """
        self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
        self._free_noise_context_stride = context_stride
        self._free_noise_weighting_scheme = weighting_scheme
        self._free_noise_shuffle = shuffle

        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
        for block in blocks:
            self._enable_free_noise_in_block(block)

    def disable_free_noise(self) -> None:
        self._free_noise_context_length = None

        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
        for block in blocks:
            self._disable_free_noise_in_block(block)

    @property
    def free_noise_enabled(self):
        return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
```
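For reference, a hedged usage sketch of the mixin as exposed through the pipelines in this changeset; the checkpoint names and generation parameters are illustrative assumptions, not part of the diff:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints: any SD1.5 base + AnimateDiff motion adapter pair should work.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Swaps BasicTransformerBlock -> FreeNoiseTransformerBlock inside the motion
# modules and enables window-based handling of the initial noise.
pipe.enable_free_noise(context_length=16, context_stride=4)

video = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=64,           # longer than the adapter's training length
    num_inference_steps=25,
    decode_batch_size=16,    # new argument added in this changeset
).frames[0]

pipe.disable_free_noise()
```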
src/diffusers/pipelines/pia/pipeline_pia.py:

```diff
@@ -45,6 +45,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin


@@ -131,6 +132,7 @@ class PIAPipeline(
     StableDiffusionLoraLoaderMixin,
     FromSingleFileMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation.
@@ -407,15 +409,21 @@ class PIAPipeline(

         return image_embeds, uncond_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, decode_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -687,6 +695,7 @@ class PIAPipeline(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -763,6 +772,8 @@ class PIAPipeline(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            decode_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.

         Examples:

@@ -931,7 +942,7 @@ class PIAPipeline(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # 10. Offload all models
```