Compare commits

...

10 Commits

Author SHA1 Message Date
Aryan
43ec0bd224 Merge branch 'main' into cogvideox/freenoise 2024-09-11 14:30:06 +02:00
Aryan
9aa2e97e69 update 2024-09-11 14:29:50 +02:00
Aryan
e07fe043a8 update progress 2024-09-11 13:15:57 +02:00
Aryan
cce65ab93f update progress 2024-09-11 00:55:00 +02:00
Aryan
052eeb5da9 update progress 2024-09-10 16:37:46 +02:00
Aryan
c9454bd4f1 Merge branch 'main' into cogvideox/freenoise 2024-09-10 13:51:21 +02:00
Aryan
2e7502f810 make style 2024-09-08 13:19:41 +02:00
Aryan
17b7f8ab8e fix bugs 2024-09-08 13:19:11 +02:00
Aryan
a012fa5748 update progress 2024-09-08 00:04:15 +02:00
Aryan
6e03e72cff update cogvideox freenoise progress 2024-09-07 23:44:03 +02:00
5 changed files with 1053 additions and 73 deletions
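Before the per-file diffs, a minimal usage sketch of the API surface this branch adds (hedged: the checkpoint name "THUDM/CogVideoX-2b", the frame count, and the inference settings are assumptions for illustration, and the branch is still marked "update progress", so exact behavior may change):

import torch
from diffusers import CogVideoXPipeline

# Assumed checkpoint; any checkpoint loadable by CogVideoXPipeline should work the same way.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# Defaults introduced in this branch: context_length=13 latent frames (49 pixel-space frames),
# context_stride=4 latent frames (16 pixel-space frames).
pipe.enable_free_noise(
    context_length=13,
    context_stride=4,
    weighting_scheme="pyramid",
    noise_type="shuffle_context",
)

# With FreeNoise enabled, `prompt` may be a dict mapping 0-based frame indices to prompts.
# Frame 0 must be present and the largest key must be smaller than `num_frames`.
prompt = {
    0: "A cat walking through a sunlit garden",
    48: "A cat curled up and sleeping under a tree",
}

video = pipe(prompt=prompt, num_frames=81, num_inference_steps=50).frames[0]

The `enable_free_noise` arguments and the dict-of-prompts form of `prompt` correspond to the additions in `free_noise_utils.py` and `pipeline_cogvideox.py` shown below.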

View File

@@ -1054,9 +1054,8 @@ class FreeNoiseTransformerBlock(nn.Module):
accumulated_values = torch.zeros_like(hidden_states)
for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
# essentially a non-multiple of `context_length`.
# The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
# if the user provided a video with 19 frames, or essentially a non-multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
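The rewritten comment above can be made concrete with a small standalone sketch of the sliding-window bookkeeping (this mirrors `_get_frame_indices` and the "pyramid" branch of `_get_frame_weights` that appear later in this diff; it is illustrative only and assumes `num_frames >= context_length`, as the block itself enforces):

from typing import List, Tuple


def get_frame_indices(num_frames: int, context_length: int, context_stride: int) -> List[Tuple[int, int]]:
    # Sliding windows of `context_length` frames shifted by `context_stride`.
    indices = [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]
    # If the last window does not reach the final frame, append one more window ending at `num_frames`.
    if indices[-1][1] != num_frames:
        indices.append((num_frames - context_length, num_frames))
    return indices


def pyramid_weights(n: int) -> List[int]:
    # e.g. n=4 -> [1, 2, 2, 1]; n=5 -> [1, 2, 3, 2, 1]
    half = list(range(1, n // 2 + 1))
    return half + half[::-1] if n % 2 == 0 else half + [n // 2 + 1] + half[::-1]


print(get_frame_indices(num_frames=19, context_length=16, context_stride=4))  # [(0, 16), (3, 19)]
print(pyramid_weights(5))  # [1, 2, 3, 2, 1]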

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -354,9 +354,12 @@ class CogVideoXPatchEmbed(nn.Module):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.sample_height = sample_height
self.text_embed_dim = text_embed_dim
self.bias = bias
self.sample_width = sample_width
self.sample_height = sample_height
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
@@ -377,7 +380,6 @@ class CogVideoXPatchEmbed(nn.Module):
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
@@ -387,12 +389,7 @@ class CogVideoXPatchEmbed(nn.Module):
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
return joint_pos_embedding
return pos_embedding
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
@@ -409,11 +406,7 @@ class CogVideoXPatchEmbed(nn.Module):
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]
if self.use_positional_embeddings:
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -423,13 +416,114 @@ class CogVideoXPatchEmbed(nn.Module):
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
else:
pos_embedding = self.pos_embedding
embeds = embeds + pos_embedding
image_embeds = image_embeds + pos_embedding
return embeds
return text_embeds, image_embeds
class FreeNoiseCogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.text_embed_dim = text_embed_dim
self.bias = bias
self.sample_width = sample_width
self.sample_height = sample_height
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings:
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# Copied from diffusers.models.embeddings.CogVideoXPatchEmbed
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
return pos_embedding
def forward(self, text_embeds: Union[torch.Tensor, Tuple[Dict[int, torch.Tensor]]], image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
if isinstance(text_embeds, torch.Tensor):
text_embeds = self.text_proj(text_embeds)
else:
assert isinstance(text_embeds, tuple)
text_embeds_output = []
for tuple_index in range(len(text_embeds)):
text_embeds_output.append({})
for key, text_embed in list(text_embeds[tuple_index].items()):
text_embeds_output[tuple_index][key] = self.text_proj(text_embed)
text_embeds = tuple(text_embeds_output)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]
if self.use_positional_embeddings:
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
else:
pos_embedding = self.pos_embedding
image_embeds = image_embeds + pos_embedding
return text_embeds, image_embeds
def get_3d_rotary_pos_embed(
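As a sanity check on the sizes involved in `_get_positional_embeddings`, the constructor defaults shown above work out as follows (a worked example only, using the default values from this diff):

# Constructor defaults from the patch embed above.
patch_size = 2
sample_width, sample_height, sample_frames = 90, 60, 49
temporal_compression_ratio = 4
max_text_seq_length = 226

post_patch_height = sample_height // patch_size  # 30
post_patch_width = sample_width // patch_size  # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13 latent frames
num_patches = post_patch_height * post_patch_width * post_time_compression_frames  # 17550

# Before this change, CogVideoXPatchEmbed returned a joint positional embedding of length
# max_text_seq_length + num_patches; after it, only the image part of length num_patches is used
# and the text embeddings are returned separately.
print(post_time_compression_frames, num_patches)  # 13 17550
print(max_text_seq_length + num_patches)  # 17776 (old joint sequence length)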

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -86,6 +86,20 @@ class CogVideoXBlock(nn.Module):
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.time_embed_dim = time_embed_dim
self.dropout = dropout
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.qk_norm = qk_norm
self.norm_elementwise_affine = norm_elementwise_affine
self.norm_eps = norm_eps
self.final_dropout = final_dropout
self.ff_inner_dim = ff_inner_dim
self.ff_bias = ff_bias
self.attention_out_bias = attention_out_bias
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
@@ -119,6 +133,7 @@ class CogVideoXBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
num_frames: Optional[int] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@@ -152,6 +167,410 @@ class CogVideoXBlock(nn.Module):
return hidden_states, encoder_hidden_states
# norm_final is just nn.LayerNorm; all ops are on the channel dimension, so we don't have to care about the frame dimension
# proj_out also acts only along the channel dimension
# norm_out is just a linear layer and nn.LayerNorm, again only on the channel dimension, so frame dims don't matter
# same story with norm1, norm2 and ff
# the patch embed layer also applies only on the channel dim and condenses to [B, FHW, C]
# only the attention layer actually mixes information across the frame dimension, so it is the only place where FreeNoise needs to be applied
# Since it does not matter for norm1, norm2 and ff, and they might create a memory bottleneck, just use the FreeNoise frame split on them too
@maybe_allow_in_graph
class FreeNoiseCogVideoXBlock(nn.Module):
r"""
FreeNoise block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool`, defaults to `True`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
context_length: int = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
prompt_interpolation_callback: Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor] = None,
prompt_pooling_callback: Callable[[List[torch.Tensor]], torch.Tensor] = None,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.time_embed_dim = time_embed_dim
self.dropout = dropout
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.qk_norm = qk_norm
self.norm_elementwise_affine = norm_elementwise_affine
self.norm_eps = norm_eps
self.final_dropout = final_dropout
self.ff_inner_dim = ff_inner_dim
self.ff_bias = ff_bias
self.attention_out_bias = attention_out_bias
self.set_free_noise_properties(
context_length, context_stride, weighting_scheme, prompt_interpolation_callback, prompt_pooling_callback
)
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock.set_free_noise_properties
def set_free_noise_properties(
self,
context_length: int,
context_stride: int,
weighting_scheme: str = "pyramid",
prompt_interpolation_callback: Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor] = None,
prompt_pooling_callback: Callable[[List[torch.Tensor]], torch.Tensor] = None,
) -> None:
if prompt_interpolation_callback is None:
raise ValueError("Must pass a callback to interpolate between prompt embeddings.")
if prompt_pooling_callback is None:
raise ValueError("Must pass a callback to pool prompt embeddings.")
self.context_length = context_length
self.context_stride = context_stride
self.weighting_scheme = weighting_scheme
self.prompt_interpolation_callback = prompt_interpolation_callback
self.prompt_pooling_callback = prompt_pooling_callback
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock._get_frame_indices
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))
return frame_indices
# Copied from diffusers.models.attention.FreeNoiseTransformerBlock._get_frame_weights
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "flat":
weights = [1.0] * num_frames
elif weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
mid = num_frames // 2
weights = list(range(1, mid + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = list(range(1, mid))
weights = weights + [mid] + weights[::-1]
elif weighting_scheme == "delayed_reverse_sawtooth":
if num_frames % 2 == 0:
# num_frames = 4 => [0.01, 2, 2, 1]
mid = num_frames // 2
weights = [0.01] * (mid - 1) + [mid]
weights = weights + list(range(mid, 0, -1))
else:
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = [0.01] * mid
weights = weights + list(range(mid, 0, -1))
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
return weights
def _prepare_free_noise_encoder_hidden_states(
self,
encoder_hidden_states: Union[
torch.Tensor, List[torch.Tensor], Tuple[Dict[int, torch.Tensor], Optional[Dict[int, torch.Tensor]]]
],
frame_indices: List[int],
) -> List[torch.Tensor]:
if torch.is_tensor(encoder_hidden_states):
encoder_hidden_states = [encoder_hidden_states.clone() for _ in range(len(frame_indices))]
elif isinstance(encoder_hidden_states, tuple):
print("frame_indices:", frame_indices)
pooled_prompt_embeds_list = []
pooled_negative_prompt_embeds_list = []
negative_prompt_embeds_dict, prompt_embeds_dict = encoder_hidden_states
last_frame_start = 0
# For every batch of frames that is to be processed, pool the positive and negative prompt embeddings.
# TODO(aryan): Since this is experimental, I didn't try many different things. I found from testing
# that pooling with the previous batch's frame embeddings is necessary to produce better results and
# helps with prompt transitions.
for frame_start, frame_end in frame_indices:
pooled_prompt_embeds = None
pooled_negative_prompt_embeds = None
pooling_list = [
prompt_embeds_dict[i] for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict
]
if len(pooling_list) > 0:
print("pooling", [i for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict])
pooled_prompt_embeds = self.prompt_pooling_callback(pooling_list)
print("after pooling:", pooled_prompt_embeds.isnan().any())
if negative_prompt_embeds_dict is not None:
pooling_list = [
negative_prompt_embeds_dict[i]
for i in range(last_frame_start, frame_end)
if i in negative_prompt_embeds_dict
]
if len(pooling_list) > 0:
print(
"negative pooling", [i for i in range(last_frame_start, frame_end) if i in prompt_embeds_dict]
)
pooled_negative_prompt_embeds = self.prompt_pooling_callback(pooling_list)
print("after negative pooling:", pooled_negative_prompt_embeds.isnan().any())
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
pooled_negative_prompt_embeds_list.append(pooled_negative_prompt_embeds)
last_frame_start = frame_start
assert pooled_prompt_embeds_list[0] is not None
assert pooled_prompt_embeds_list[-1] is not None
if negative_prompt_embeds_dict is not None:
assert pooled_negative_prompt_embeds_list[0] is not None
assert pooled_negative_prompt_embeds_list[-1] is not None
# If there were no relevant prompts for certain frame batches, interpolate and fill in the gaps
last_existent_embed_index = 0
for i in range(1, len(frame_indices)):
if pooled_prompt_embeds_list[i] is not None and i - last_existent_embed_index > 1:
print("interpolating:", last_existent_embed_index, i)
interpolated_embeds = self.prompt_interpolation_callback(
last_existent_embed_index,
i,
pooled_prompt_embeds_list[last_existent_embed_index],
pooled_prompt_embeds_list[i],
)
print("after interpolating", interpolated_embeds.isnan().any())
pooled_prompt_embeds_list[last_existent_embed_index : i + 1] = interpolated_embeds.split(1, dim=0)
last_existent_embed_index = i
assert all(x is not None for x in pooled_prompt_embeds_list)
if negative_prompt_embeds_dict is not None:
last_existent_embed_index = 0
for i in range(1, len(frame_indices)):
if pooled_negative_prompt_embeds_list[i] is not None and i - last_existent_embed_index > 1:
print("negative interpolating:", last_existent_embed_index, i)
interpolated_embeds = self.prompt_interpolation_callback(
last_existent_embed_index,
i,
pooled_negative_prompt_embeds_list[last_existent_embed_index],
pooled_negative_prompt_embeds_list[i],
)
print("after negative interpolating", interpolated_embeds.isnan().any())
pooled_negative_prompt_embeds_list[
last_existent_embed_index : i + 1
] = interpolated_embeds.split(1, dim=0)
last_existent_embed_index = i
assert all(x is not None for x in pooled_negative_prompt_embeds_list)
if negative_prompt_embeds_dict is not None:
# Classifier-Free Guidance
pooled_prompt_embeds_list = [
torch.cat([negative_prompt_embeds, prompt_embeds])
for negative_prompt_embeds, prompt_embeds in zip(
pooled_negative_prompt_embeds_list, pooled_prompt_embeds_list
)
]
encoder_hidden_states = pooled_prompt_embeds_list
elif not isinstance(encoder_hidden_states, list):
raise ValueError(
f"Expected `encoder_hidden_states` to be a tensor, a list of tensors, or a tuple of dictionaries, but found {type(encoder_hidden_states)=}"
)
assert isinstance(encoder_hidden_states, list) and len(encoder_hidden_states) == len(frame_indices)
return encoder_hidden_states
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
num_frames: Optional[int] = None,
) -> torch.Tensor:
# hidden_states: [B, F x H x W, C]
device = hidden_states.device
dtype = hidden_states.dtype
frame_indices = self._get_frame_indices(num_frames)
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
frame_weights = (
torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
)
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
# Handle the out-of-bounds case if num_frames isn't covered exactly by the sliding windows.
# For example, with num_frames=25, context_length=16 and context_stride=4, the computed ranges are
# [(0, 16), (4, 20), (8, 24)], which leave the last frame uncovered, so we append (9, 25).
if not is_last_frame_batch_complete:
if num_frames < self.context_length:
raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
last_frame_batch_length = num_frames - frame_indices[-1][1]
frame_indices.append((num_frames - self.context_length, num_frames))
# Unflatten frame dimension: [B, F, HW, C]
batch_size, frames_height_width, channels = hidden_states.shape
hidden_states = hidden_states.reshape(batch_size, num_frames, frames_height_width // num_frames, channels)
encoder_hidden_states = self._prepare_free_noise_encoder_hidden_states(encoder_hidden_states, frame_indices)
num_times_accumulated = torch.zeros((1, num_frames, 1, 1), device=device)
accumulated_values = torch.zeros_like(hidden_states)
text_seq_length = _get_text_seq_length(encoder_hidden_states)
for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
# if the user provided a video with 19 frames, or essentially a non-multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
# Flatten frame dimension: [B, F'HW, C]
hidden_states_chunk = hidden_states[:, frame_start:frame_end].flatten(1, 2)
print(
"debug:",
text_seq_length,
torch.isnan(hidden_states_chunk).any(),
torch.isnan(encoder_hidden_states[i]).any(),
)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states_chunk, encoder_hidden_states[i], temb
)
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states_chunk = hidden_states_chunk + gate_msa * attn_hidden_states
encoder_hidden_states[i] = encoder_hidden_states[i] + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states_chunk, encoder_hidden_states[i], temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states_chunk = hidden_states_chunk + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states[i] = encoder_hidden_states[i] + enc_gate_ff * ff_output[:, :text_seq_length]
# Unflatten frame dimension: [B, F', HW, C]
_num_frames = frame_end - frame_start
hidden_states_chunk = hidden_states_chunk.reshape(batch_size, _num_frames, -1, channels)
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
accumulated_values[:, -last_frame_batch_length:] += (
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
)
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
else:
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
# TODO(aryan): Maybe this could be done in a better way.
#
# Previously, this was:
# hidden_states = torch.where(
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# )
#
# The reasoning for the change here is that `torch.where` became a bottleneck at some point while golfing memory
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
# from tensors being copied - which is why we resort to splitting and concatenating here. I haven't looked
# into this particularly deeply because other memory optimizations led to more pronounced reductions.
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
).to(dtype)
# Flatten frame dimension: [B, FHW, C]
hidden_states = hidden_states.flatten(1, 2)
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -417,13 +836,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
text_seq_length = _get_text_seq_length(encoder_hidden_states)
encoder_hidden_states, hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
@@ -441,6 +857,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states,
emb,
image_rotary_emb,
num_frames,
**ckpt_kwargs,
)
else:
@@ -449,6 +866,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
num_frames=num_frames,
)
if not self.config.use_rotary_positional_embeddings:
@@ -472,3 +890,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _get_text_seq_length(x) -> int:
if isinstance(x, torch.Tensor):
return x.shape[1]
if isinstance(x, list):
return _get_text_seq_length(x[0])
if isinstance(x, dict):
return _get_text_seq_length(next(iter(x.values())))
return None
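The split-and-concatenate reduction used in `FreeNoiseCogVideoXBlock.forward` (see the long comment about `torch.where` above) is numerically equivalent to the original single `torch.where`; a small self-contained sketch with made-up shapes:

import torch

context_length = 4
# [batch, frames, tokens, channels] accumulation buffers, as in the FreeNoise block.
accumulated_values = torch.randn(1, 10, 6, 8)
num_times_accumulated = torch.randint(0, 3, (1, 10, 1, 1)).float()

# Original formulation: a single torch.where over the whole frame dimension.
reference = torch.where(
    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
)

# Formulation from this diff: split along frames, normalize chunk by chunk, then concatenate.
chunked = torch.cat(
    [
        torch.where(n > 0, a / n, a)
        for a, n in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)

print(torch.allclose(reference, chunked))  # True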

View File

@@ -28,6 +28,7 @@ from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_noise_utils import CogVideoXFreeNoiseMixin
from .pipeline_output import CogVideoXPipelineOutput
@@ -136,7 +137,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class CogVideoXPipeline(DiffusionPipeline):
class CogVideoXPipeline(DiffusionPipeline, CogVideoXFreeNoiseMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -316,9 +317,17 @@ class CogVideoXPipeline(DiffusionPipeline):
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_frames,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
@@ -392,8 +401,8 @@ class CogVideoXPipeline(DiffusionPipeline):
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -470,7 +479,7 @@ class CogVideoXPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
prompt: Optional[Union[str, List[str], Dict[int, str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
@@ -497,7 +506,7 @@ class CogVideoXPipeline(DiffusionPipeline):
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt (`str` or `List[str]` or `Dict[int, str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
@@ -569,11 +578,6 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -595,7 +599,7 @@ class CogVideoXPipeline(DiffusionPipeline):
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -610,18 +614,33 @@ class CogVideoXPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
prompt_embeds_dtype = next(iter(prompt_embeds[0].values())).dtype
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
prompt_embeds_dtype = prompt_embeds.dtype
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -635,7 +654,7 @@ class CogVideoXPipeline(DiffusionPipeline):
num_frames,
height,
width,
prompt_embeds.dtype,
prompt_embeds_dtype,
device,
generator,
latents,
@@ -699,7 +718,7 @@ class CogVideoXPipeline(DiffusionPipeline):
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
latents = latents.to(prompt_embeds_dtype)
# call the callback, if provided
if callback_on_step_end is not None:
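`prepare_latents` above defers to `_prepare_latents_free_noise`, defined in the next file. The core "shuffle_context" idea from Equation (7) of the FreeNoise paper can be sketched independently as follows (a simplified illustration with assumed latent sizes, not the pipeline's exact implementation):

from typing import Optional

import torch


def shuffle_context_latents(
    num_frames: int,
    context_length: int,
    context_stride: int,
    channels: int = 16,
    height: int = 8,
    width: int = 8,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # [batch, frames, channels, height, width] latent layout, as in CogVideoX.
    latents = torch.randn(1, num_frames, channels, height, width, generator=generator)
    # Re-use (shuffle) noise from earlier windows for later frames so that overlapping windows
    # denoise correlated noise, which is what keeps long videos temporally consistent.
    for i in range(context_length, num_frames, context_stride):
        window_start = max(0, i - context_length)
        window_end = min(num_frames, window_start + context_stride)
        window_length = window_end - window_start
        if window_length == 0:
            break
        indices = torch.arange(window_start, window_end)
        shuffled = indices[torch.randperm(window_length, generator=generator)]
        current_end = min(num_frames, i + window_length)
        latents[:, i:current_end] = latents[:, shuffled[: current_end - i]]
    return latents


latents = shuffle_context_latents(num_frames=21, context_length=13, context_stride=4)
print(latents.shape)  # torch.Size([1, 21, 16, 8, 8])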

View File

@@ -12,13 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.embeddings import CogVideoXPatchEmbed, FreeNoiseCogVideoXPatchEmbed
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..models.transformers.cogvideox_transformer_3d import (
CogVideoXBlock,
CogVideoXTransformer3DModel,
FreeNoiseCogVideoXBlock,
)
from ..models.transformers.transformer_2d import Transformer2DModel
from ..models.unets.unet_motion_model import (
AnimateDiffTransformer3D,
@@ -26,7 +34,6 @@ from ..models.unets.unet_motion_model import (
DownBlockMotion,
UpBlockMotion,
)
from ..pipelines.pipeline_utils import DiffusionPipeline
from ..utils import logging
from ..utils.torch_utils import randn_tensor
@@ -34,6 +41,56 @@ from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _lerp(start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor:
num_indices = end_index - start_index + 1
interpolated_tensors = []
for i in range(num_indices):
alpha = i / (num_indices - 1)
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
interpolated_tensors.append(interpolated_tensor)
interpolated_tensors = torch.cat(interpolated_tensors)
return interpolated_tensors
def _weighted_pooling(
hidden_states: List[torch.Tensor], pooling_type: str = "inverse_distance", decay_rate: float = 0.9
) -> torch.Tensor:
length = len(hidden_states)
if length == 0:
raise ValueError("This method cannot be called with an empty list.")
if length == 1:
return hidden_states[0]
if pooling_type == "average":
weights = [1] * length
elif pooling_type == "exponential_decay":
weights = [decay_rate**i for i in range(length)]
elif pooling_type == "reverse_exponential_decay":
weights = [decay_rate**i for i in range(length)][::-1]
elif pooling_type == "distance":
weights = list(range(1, length + 1))
elif pooling_type == "reverse_distance":
weights = list(range(1, length + 1))[::-1]
elif pooling_type == "pyramid":
if length % 2 == 0:
weights = list(range(1, length // 2 + 1))
weights = weights + weights[::-1]
else:
weights = list(range(1, length // 2 + 1))
weights = weights + [length // 2] + weights[::-1]
else:
raise ValueError("Invalid weighted pooling type.")
weights = torch.tensor(weights, device=hidden_states[0].device, dtype=hidden_states[0].dtype)
weights = weights / weights.sum()
pooled_embeds = sum(weight * hidden_state for weight, hidden_state in zip(weights, hidden_states))
return pooled_embeds
class SplitInferenceModule(nn.Module):
r"""
A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
@@ -143,7 +200,7 @@ class SplitInferenceModule(nn.Module):
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169) as used in AnimateDiff."""
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to enable FreeNoise in transformer blocks."""
@@ -265,7 +322,7 @@ class AnimateDiffFreeNoiseMixin:
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if negative_prompt is None:
negative_prompt = ""
@@ -427,29 +484,13 @@ class AnimateDiffFreeNoiseMixin:
latents = latents[:, :, :num_frames]
return latents
def _lerp(
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
num_indices = end_index - start_index + 1
interpolated_tensors = []
for i in range(num_indices):
alpha = i / (num_indices - 1)
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
interpolated_tensors.append(interpolated_tensor)
interpolated_tensors = torch.cat(interpolated_tensors)
return interpolated_tensors
def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
prompt_interpolation_callback: Optional[
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
] = None,
prompt_interpolation_callback: Optional[Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> None:
r"""
Enable long video generation using FreeNoise.
@@ -505,7 +546,7 @@ class AnimateDiffFreeNoiseMixin:
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or _lerp
if hasattr(self.unet.mid_block, "motion_modules"):
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
@@ -594,3 +635,402 @@ class AnimateDiffFreeNoiseMixin:
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
class CogVideoXFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169) as used in CogVideoX."""
def _enable_free_noise_in_block(self, transformer: CogVideoXTransformer3DModel):
r"""Helper function to enable FreeNoise in transformer blocks."""
patch_embed = transformer.patch_embed
transformer.patch_embed = FreeNoiseCogVideoXPatchEmbed(
patch_size=patch_embed.patch_size,
in_channels=patch_embed.in_channels,
embed_dim=patch_embed.embed_dim,
text_embed_dim=patch_embed.text_embed_dim,
bias=patch_embed.bias,
sample_width=patch_embed.sample_width,
sample_height=patch_embed.sample_height,
sample_frames=patch_embed.sample_frames,
temporal_compression_ratio=patch_embed.temporal_compression_ratio,
max_text_seq_length=patch_embed.max_text_seq_length,
spatial_interpolation_scale=patch_embed.spatial_interpolation_scale,
temporal_interpolation_scale=patch_embed.temporal_interpolation_scale,
use_positional_embeddings=patch_embed.use_positional_embeddings,
).to(device=self.device, dtype=self.dtype)
transformer.patch_embed.load_state_dict(patch_embed.state_dict(), strict=True)
for i in range(len(transformer.transformer_blocks)):
block = transformer.transformer_blocks[i]
if isinstance(block, FreeNoiseCogVideoXBlock):
block.set_free_noise_properties(
self._free_noise_context_length,
self._free_noise_context_stride,
self._free_noise_weighting_scheme,
self._free_noise_prompt_interpolation_callback,
self._free_noise_prompt_pooling_callback,
)
else:
transformer.transformer_blocks[i] = FreeNoiseCogVideoXBlock(
dim=block.dim,
num_attention_heads=block.num_attention_heads,
attention_head_dim=block.attention_head_dim,
time_embed_dim=block.time_embed_dim,
dropout=block.dropout,
activation_fn=block.activation_fn,
attention_bias=block.attention_bias,
qk_norm=block.qk_norm,
norm_elementwise_affine=block.norm_elementwise_affine,
norm_eps=block.norm_eps,
final_dropout=block.final_dropout,
ff_inner_dim=block.ff_inner_dim,
ff_bias=block.ff_bias,
attention_out_bias=block.attention_out_bias,
context_length=self._free_noise_context_length,
context_stride=self._free_noise_context_stride,
weighting_scheme=self._free_noise_weighting_scheme,
prompt_interpolation_callback=self._free_noise_prompt_interpolation_callback,
prompt_pooling_callback=self._free_noise_prompt_pooling_callback,
).to(device=self.device, dtype=self.dtype)
transformer.transformer_blocks[i].load_state_dict(block.state_dict(), strict=True)
def _disable_free_noise_in_block(self, transformer: CogVideoXTransformer3DModel):
r"""Helper function to disable FreeNoise in transformer blocks."""
patch_embed = transformer.patch_embed
transformer.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_embed.patch_size,
in_channels=patch_embed.in_channels,
embed_dim=patch_embed.embed_dim,
text_embed_dim=patch_embed.text_embed_dim,
bias=patch_embed.bias,
sample_width=patch_embed.sample_width,
sample_height=patch_embed.sample_height,
sample_frames=patch_embed.sample_frames,
temporal_compression_ratio=patch_embed.temporal_compression_ratio,
max_text_seq_length=patch_embed.max_text_seq_length,
spatial_interpolation_scale=patch_embed.spatial_interpolation_scale,
temporal_interpolation_scale=patch_embed.temporal_interpolation_scale,
use_positional_embeddings=patch_embed.use_positional_embeddings,
).to(device=self.device, dtype=self.dtype)
transformer.patch_embed.load_state_dict(patch_embed.state_dict(), strict=True)
for i in range(len(transformer.transformer_blocks)):
block = transformer.transformer_blocks[i]
if isinstance(block, FreeNoiseCogVideoXBlock):
transformer.transformer_blocks[i] = CogVideoXBlock(
dim=block.dim,
num_attention_heads=block.num_attention_heads,
attention_head_dim=block.attention_head_dim,
time_embed_dim=block.time_embed_dim,
dropout=block.dropout,
activation_fn=block.activation_fn,
attention_bias=block.attention_bias,
qk_norm=block.qk_norm,
norm_elementwise_affine=block.norm_elementwise_affine,
norm_eps=block.norm_eps,
final_dropout=block.final_dropout,
ff_inner_dim=block.ff_inner_dim,
ff_bias=block.ff_bias,
attention_out_bias=block.attention_out_bias,
).to(device=self.device, dtype=self.dtype)
transformer.transformer_blocks[i].load_state_dict(block.state_dict(), strict=True)
def _check_inputs_free_noise(
self,
prompt,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
num_frames,
) -> None:
if not isinstance(prompt, (str, dict)):
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
if negative_prompt is not None:
if not isinstance(negative_prompt, (str, dict)):
raise ValueError(
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
)
if prompt_embeds is not None or negative_prompt_embeds is not None:
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` are not supported in FreeNoise yet.")
frame_indices = [isinstance(x, int) for x in prompt.keys()]
frame_prompts = [isinstance(x, str) for x in prompt.values()]
min_frame = min(list(prompt.keys()))
max_frame = max(list(prompt.keys()))
if not all(frame_indices):
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
if not all(frame_prompts):
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
if min_frame != 0:
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
if max_frame >= num_frames:
raise ValueError(
f"The maximum frame index in `prompt` dict must be less than {num_frames=} and follow 0-based indexing."
)
def _encode_prompt_free_noise(
self,
prompt: Union[str, Dict[int, str]],
num_frames: int,
device: torch.device,
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
) -> Tuple[Dict[int, torch.Tensor], Optional[Dict[int, torch.Tensor]]]:
if negative_prompt is None:
negative_prompt = ""
# Ensure that we have a dictionary of prompts
if isinstance(prompt, str):
prompt = {0: prompt}
if isinstance(negative_prompt, str):
negative_prompt = {0: negative_prompt}
self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
# Sort the prompts based on frame indices
prompt = dict(sorted(prompt.items()))
negative_prompt = dict(sorted(negative_prompt.items()))
# Ensure that we have a prompt for the last frame index
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
frame_indices = list(prompt.keys())
frame_prompts = list(prompt.values())
frame_negative_indices = list(negative_prompt.keys())
frame_negative_prompts = list(negative_prompt.values())
# Generate and bucketify positive prompts
prompt_embeds, _ = self.encode_prompt(
prompt=frame_prompts,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=False,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
max_sequence_length=max_sequence_length,
)
# If many prompts fall into the same latent frame bucket, pool them
prompt_embeds_frame_buckets = collections.defaultdict(lambda: [])
for i in range(len(frame_indices)):
latent_frame_index = frame_indices[i] // self.transformer.config.temporal_compression_ratio
prompt_embeds_frame_buckets[latent_frame_index].append(prompt_embeds[i].unsqueeze(0))
for frame_index, embeds in list(prompt_embeds_frame_buckets.items()):
prompt_embeds_frame_buckets[frame_index] = self._free_noise_prompt_pooling_callback(embeds)
# Generate and bucketify negative prompts
negative_prompt_embeds_frame_buckets = None
if do_classifier_free_guidance:
_, negative_prompt_embeds = self.encode_prompt(
prompt=[""] * len(frame_negative_prompts),
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=True,
negative_prompt=frame_negative_prompts,
prompt_embeds=None,
negative_prompt_embeds=None,
max_sequence_length=max_sequence_length,
)
# If many prompts fall into the same latent frame bucket, pool them
negative_prompt_embeds_frame_buckets = collections.defaultdict(lambda: [])
for i in range(len(frame_negative_indices)):
latent_frame_index = frame_negative_indices[i] // self.transformer.config.temporal_compression_ratio
negative_prompt_embeds_frame_buckets[latent_frame_index].append(negative_prompt_embeds[i].unsqueeze(0))
for frame_index, embeds in list(negative_prompt_embeds_frame_buckets.items()):
negative_prompt_embeds_frame_buckets[frame_index] = self._free_noise_prompt_pooling_callback(embeds)
prompt_embeds_frame_buckets = (negative_prompt_embeds_frame_buckets, prompt_embeds_frame_buckets)
else:
prompt_embeds_frame_buckets = (prompt_embeds_frame_buckets,)
return prompt_embeds_frame_buckets, negative_prompt_embeds_frame_buckets
def _prepare_latents_free_noise(
self,
batch_size: int,
num_channels_latents: int,
num_frames: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
context_num_frames = (
self._free_noise_context_length if self._free_noise_noise_type == "repeat_context" else num_frames
)
shape = (
batch_size,
context_num_frames,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if self._free_noise_noise_type == "random":
return latents
else:
if latents.size(1) == num_frames:
return latents
elif latents.size(1) != self._free_noise_context_length:
raise ValueError(
f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(1)}"
)
latents = latents.to(device)
if self._free_noise_noise_type == "shuffle_context":
for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
# ensure window is within bounds
window_start = max(0, i - self._free_noise_context_length)
window_end = min(num_frames, window_start + self._free_noise_context_stride)
window_length = window_end - window_start
if window_length == 0:
break
indices = torch.LongTensor(list(range(window_start, window_end)))
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
current_start = i
current_end = min(num_frames, current_start + window_length)
if current_end == current_start + window_length:
# batch of frames perfectly fits the window
latents[:, current_start:current_end] = latents[:, shuffled_indices]
else:
# handle the case where the last batch of frames does not fit perfectly with the window
prefix_length = current_end - current_start
shuffled_indices = shuffled_indices[:prefix_length]
latents[:, current_start:current_end] = latents[:, shuffled_indices]
elif self._free_noise_noise_type == "repeat_context":
num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
latents = torch.cat([latents] * num_repeats, dim=1)
latents = latents[:, :num_frames]
return latents
def enable_free_noise(
self,
context_length: Optional[int] = 13, # 49 pixel-space frames
context_stride: int = 4, # 16 pixel-space frames
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
prompt_interpolation_callback: Optional[Callable[[int, int, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
prompt_pooling_type_or_callback: Optional[Union[str, Callable[[List[torch.Tensor]], torch.Tensor]]] = None,
) -> None:
r"""
Enable long video generation using FreeNoise.
Args:
context_length (`int`, defaults to `13`, *optional*):
The number of latent frames to process at once. It's recommended to set this to the number of latent
frames the transformer was trained with (13 for CogVideoX, corresponding to 49 pixel-space frames).
If `None`, a value derived from the transformer config is used.
context_stride (`int`, *optional*):
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
[0, 15], [4, 19], [8, 23] (0-based indexing)
weighting_scheme (`str`, defaults to `pyramid`):
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
schemes are supported currently:
- "flat"
Performs weighted averaging with a flat weight pattern: [1, 1, 1, 1, 1].
- "pyramid"
Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
- "delayed_reverse_sawtooth"
Performs weighted averaging with low weights for earlier frames and high-to-low weights for
later frames: [0.01, 0.01, 3, 2, 1].
noise_type (`str`, defaults to "shuffle_context"):
Must be one of ["shuffle_context", "repeat_context", "random"].
- "shuffle_context"
Shuffles a fixed batch of `context_length` latents to create a final latent of size
`num_frames`. This is usually the best setting for most generation scenarios. However, there
might be noticeable repetition in the kinds of motion/animation generated.
- "repeat_context"
Repeats a fixed batch of `context_length` latents to create a final latent of size
`num_frames`.
- "random"
The final latents are random without any repetition.
"""
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
# Note: we need to do this because when CogVideoX was originally introduced, it used the wrong value
# for `sample_frames`. It should have been 13 and not 49 because the expected number of latent frames
# in the transformer is 13.
# Maybe look into what can be done about this before the 1.0.0 release.
if context_length % 2 == 0:
sample_frames_context_length = context_length * self.transformer.config.temporal_compression_ratio
else:
sample_frames_context_length = (
context_length - 1
) * self.transformer.config.temporal_compression_ratio + 1
if sample_frames_context_length > self.transformer.config.sample_frames:
logger.warning(
f"You have set {context_length=}, which corresponds to more pixel-space frames than {self.transformer.config.sample_frames=}. This can lead to bad generation results."
)
if weighting_scheme not in allowed_weighting_scheme:
raise ValueError(
f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
)
if noise_type not in allowed_noise_type:
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
self._free_noise_context_length = context_length or (self.transformer.config.sample_frames - 1) // self.transformer.config.temporal_compression_ratio + 1
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or _lerp
if prompt_pooling_type_or_callback is None:
prompt_pooling_type_or_callback = "pyramid"
if isinstance(prompt_pooling_type_or_callback, str):
self._free_noise_prompt_pooling_callback = functools.partial(
_weighted_pooling, pooling_type=prompt_pooling_type_or_callback
)
else:
self._free_noise_prompt_pooling_callback = prompt_pooling_type_or_callback
self._enable_free_noise_in_block(self.transformer)
def disable_free_noise(self) -> None:
r"""Disable the FreeNoise sampling mechanism."""
self._free_noise_context_length = None
self._disable_free_noise_in_block(self.transformer)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
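Finally, `_lerp` and `_weighted_pooling` defined at the top of this file are the default prompt interpolation and pooling callbacks for `CogVideoXFreeNoiseMixin`. A hedged usage sketch with toy tensors standing in for prompt embeddings (the inline helpers below mirror those functions rather than importing the private names):

from typing import List

import torch


def lerp(start_index: int, end_index: int, start: torch.Tensor, end: torch.Tensor) -> torch.Tensor:
    # Mirrors `_lerp` above: inclusive linear interpolation between two prompt embeddings.
    n = end_index - start_index + 1
    return torch.cat([(1 - i / (n - 1)) * start + (i / (n - 1)) * end for i in range(n)])


def pyramid_pooling(embeds: List[torch.Tensor]) -> torch.Tensor:
    # Mirrors the "pyramid" branch of `_weighted_pooling` above.
    length = len(embeds)
    if length == 1:
        return embeds[0]
    weights = list(range(1, length // 2 + 1))
    weights = weights + weights[::-1] if length % 2 == 0 else weights + [length // 2] + weights[::-1]
    weights = torch.tensor(weights, dtype=embeds[0].dtype)
    weights = weights / weights.sum()
    return sum(w * e for w, e in zip(weights, embeds))


# Toy "prompt embeddings": [1, seq_length, dim]
a, b, c = (torch.randn(1, 4, 8) for _ in range(3))

pooled = pyramid_pooling([a, b, c])  # one embedding for a latent-frame bucket that collected several prompts
transition = lerp(0, 3, a, b)  # 4 embeddings blending smoothly from `a` to `b` across frame batches
print(pooled.shape, transition.shape)  # torch.Size([1, 4, 8]) torch.Size([4, 4, 8])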